From f3e7d03d93b30a79a37581108067e6bb8f428c94 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 22 Mar 2024 14:07:47 +0100 Subject: [PATCH 01/55] Added function get_img_at_mpp to class OpenSlideWSIReader of module wsi_reader.py --- monai/data/wsi_reader.py | 82 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index b31d4d9c3a..f3f099160f 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,6 +19,7 @@ import numpy as np import torch +import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -940,6 +941,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + mpp_closest_lvl = mpp_list[closest_lvl] + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + return closest_lvl_wsi + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + return closest_lvl_wsi + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + return closest_lvl_wsi + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From 88002e8a91d6466a2fdcb60a19b8cd1ed9e89558 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 22 Mar 2024 16:42:18 +0100 Subject: [PATCH 02/55] Added get_img_at_mpp to class CuCIMWSIReader --- monai/data/wsi_reader.py | 98 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index f3f099160f..0c49143220 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,6 +603,23 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) + + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + return self.reader.get_img_at_mpp(wsi, mpp, atol, rtol) def get_power(self, wsi, level: int) -> float: """ @@ -745,6 +762,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + mpp_closest_lvl = mpp_list[closest_lvl] + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + return closest_lvl_wsi + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + return closest_lvl_wsi + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + return closest_lvl_wsi + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From a9fe772d56a458c853bbb69ecef4585cb2ef564d Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 19:18:42 +0100 Subject: [PATCH 03/55] Added function get_img_at_mpp to class TifffileWSIReader; changed resizing function to Image.resize, cucim.skimage.transform.resize --- monai/data/wsi_reader.py | 160 +++++++++++++++++++++++++++++++-------- 1 file changed, 130 insertions(+), 30 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 0c49143220..4f02cee285 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,7 +19,6 @@ import numpy as np import torch -import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -778,9 +777,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] @@ -797,13 +801,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -814,15 +817,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -833,15 +837,18 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_y = mpp_closest_lvl_y / user_mpp_y closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') - return closest_lvl_wsi + + wsi_arr = cp.asnumpy(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1055,9 +1062,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] @@ -1078,9 +1088,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -1091,15 +1100,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) - + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1110,15 +1118,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_y = mpp_closest_lvl_y / user_mpp_y closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) - + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') - return closest_lvl_wsi + + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1276,8 +1285,10 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: - unit = str(unit.value)[8:] + if unit is not None: # Needs to be extended + # unit = str(unit.value)[8:] + unit = str(unit.value.name).lower() # TODO: Merge both methods + else: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") unit = "micrometer" @@ -1290,6 +1301,95 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + pil_image, _ = optional_import("PIL", name="Image") + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1? + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + + mpp_closest_lvl = mpp_list[closest_lvl] + + lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width) + closest_lvl_dim = lvl_dims[closest_lvl] + closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = lvl_dims[closest_lvl] + closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From feac0dc57ef2cc6fd3f24204d43881fd186c5595 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:17:37 +0100 Subject: [PATCH 04/55] Small changes --- monai/data/wsi_reader.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 4f02cee285..3f2e26f9e2 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -607,8 +607,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -765,8 +767,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -786,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -805,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -823,8 +827,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: @@ -843,7 +846,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') @@ -1050,8 +1052,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -1123,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1305,7 +1309,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. @@ -1319,8 +1323,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1? - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1358,7 +1362,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) @@ -1378,7 +1381,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) From 81940261f2c35ae5f8c4485054523b9306a3eeff Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:21:57 +0100 Subject: [PATCH 05/55] Small changes --- monai/data/wsi_reader.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 3f2e26f9e2..9c6ee9c387 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -790,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -809,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -827,13 +827,13 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1115,7 +1115,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1127,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1289,9 +1289,9 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Needs to be extended - # unit = str(unit.value)[8:] - unit = str(unit.value.name).lower() # TODO: Merge both methods + if unit is not None: # Needs to be improved + unit = str(unit.value)[8:] + # unit = str(unit.value.name).lower() # TODO: Merge both methods else: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") @@ -1309,8 +1309,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -1329,7 +1331,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width) + lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) From 4df0b4b61e5fded01369ea531782b4df01bac3ba Mon Sep 17 00:00:00 2001 From: cxlcl Date: Fri, 22 Mar 2024 09:54:40 -0700 Subject: [PATCH 06/55] Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (#7308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Based on the discussion topic [here](https://github.com/Project-MONAI/MONAI/discussions/7161#discussion-5773293), we implemented the Conjugate-Gradient algorithm for linear operator inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for ground-truth-date free diffusion process guidance that is proposed in [2] and illustrated in the algorithm below: Screenshot 2023-12-10 at 10 19 25 PM The Conjugate-Gradient (CG) algorithm is used to solve for the inversion of the linear operator in Line-4 in the algorithm above, where the linear operator is too large to store explicitly as a matrix (such as FFT/IFFT of an image) and invert directly. Instead, we can solve for the linear inversion iteratively as in CG. The SURE loss is applied for Line-6 above. This is a differentiable loss function that can be used to train/giude an operator (e.g. neural network), where the pseudo ground truth is available but the reference ground truth is not. For example, in the MRI reconstruction, the pseudo ground truth is the zero-filled reconstruction and the reference ground truth is the fully sampled reconstruction. The reference ground truth is not available due to the lack of fully sampled. **Reference** [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics 1981 [[paper link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)] [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. MICCAI 2023 [[paper link](https://arxiv.org/pdf/2310.01799.pdf)] ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Nikolas Schmitz --- docs/source/losses.rst | 5 + docs/source/networks.rst | 5 + monai/losses/__init__.py | 1 + monai/losses/sure_loss.py | 200 ++++++++++++++++++++ monai/networks/layers/__init__.py | 1 + monai/networks/layers/conjugate_gradient.py | 112 +++++++++++ tests/test_conjugate_gradient.py | 55 ++++++ tests/test_sure_loss.py | 71 +++++++ 8 files changed, 450 insertions(+) create mode 100644 monai/losses/sure_loss.py create mode 100644 monai/networks/layers/conjugate_gradient.py create mode 100644 tests/test_conjugate_gradient.py create mode 100644 tests/test_sure_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 61dd959807..ba794af3eb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -139,6 +139,11 @@ Reconstruction Losses .. autoclass:: JukeboxLoss :members: +`SURELoss` +~~~~~~~~~~ +.. autoclass:: SURELoss + :members: + Loss Wrappers ------------- diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 8eada7933f..b59c8af5fc 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -408,6 +408,11 @@ Layers .. autoclass:: LLTM :members: +`ConjugateGradient` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConjugateGradient + :members: + `Utilities` ~~~~~~~~~~~ .. automodule:: monai.networks.layers.convutils diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 4ebedb2084..e937b53fa4 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -41,5 +41,6 @@ from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss from .ssim_loss import SSIMLoss +from .sure_loss import SURELoss from .tversky import TverskyLoss from .unified_focal_loss import AsymmetricUnifiedFocalLoss diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py new file mode 100644 index 0000000000..ebf25613a6 --- /dev/null +++ b/monai/losses/sure_loss.py @@ -0,0 +1,200 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + + +def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + First compute the difference in the complex domain, + then get the absolute value and take the mse + + Args: + x, y - B, 2, H, W real valued tensors representing complex numbers + or B,1,H,W complex valued tensors + Returns: + l2_loss - scalar + """ + if not x.is_complex(): + x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) + if not y.is_complex(): + y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) + + diff = torch.abs(x - y) + return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") + + +def sure_loss_function( + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + eps: Optional[float] = -1.0, + perturb_noise: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, +) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. For complex input, the shape is + (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) + real. + + y_ref (torch.Tensor, optional): The reference output tensor of shape + (B, C, H, W) used to compute the divergence. Defaults to None. For + complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, + the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Set to -1 to set it + automatically estimated based on y_pseudo_gtk + + perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). + Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + complex_input(bool, optional): Whether the input is complex or not. + Defaults to False. + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + # perturb input + if perturb_noise is None: + perturb_noise = torch.randn_like(x) + if eps == -1.0: + eps = float(torch.abs(y_pseudo_gt.max())) / 1000 + # get y_ref if not provided + if y_ref is None: + y_ref = operator(x) + + # get perturbed output + x_perturbed = x + eps * perturb_noise + y_perturbed = operator(x_perturbed) + # divergence + divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore + # l2 loss between y_ref, y_pseudo_gt + if complex_input: + l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) + else: + # real input + l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean") + + # sure loss + sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3]) + return sure_loss + + +class SURELoss(_Loss): + """ + Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. + + This is a differentiable loss function that can be used to train/guide an + operator (e.g. neural network), where the pseudo ground truth is available + but the reference ground truth is not. For example, in the MRI + reconstruction, the pseudo ground truth is the zero-filled reconstruction + and the reference ground truth is the fully sampled reconstruction. Often, + the reference ground truth is not available due to the lack of fully sampled + data. + + The original SURE loss is proposed in [1]. The SURE loss used for guiding + the diffusion model based MRI reconstruction is proposed in [2]. + + Reference + + [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics + + [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. + (https://arxiv.org/pdf/2310.01799.pdf) + """ + + def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: + """ + Args: + perturb_noise (torch.Tensor, optional): The noise vector of shape + (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. + For real input, the shape is (B, 1, H, W) real. + + eps (float, optional): The perturbation scalar. Defaults to None. + """ + super().__init__() + self.perturb_noise = perturb_noise + self.eps = eps + + def forward( + self, + operator: Callable, + x: torch.Tensor, + y_pseudo_gt: torch.Tensor, + y_ref: Optional[torch.Tensor] = None, + complex_input: Optional[bool] = False, + ) -> torch.Tensor: + """ + Args: + operator (function): The operator function that takes in an input + tensor x and returns an output tensor y. We will use this to compute + the divergence. More specifically, we will perturb the input x by a + small amount and compute the divergence between the perturbed output + and the reference output + + x (torch.Tensor): The input tensor of shape (B, C, H, W) to the + operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka + C=2 real. For real input, the shape is (B, 1, H, W) real. + + y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape + (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex + input, the shape is (B, 2, H, W) aka C=2 real. For real input, the + shape is (B, 1, H, W) real. + + y_ref (torch.Tensor, optional): The reference output tensor of the + same shape as y_pseudo_gt + + Returns: + sure_loss (torch.Tensor): The SURE loss scalar. + """ + + # check inputs shapes + if x.dim() != 4: + raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.") + if y_pseudo_gt.dim() != 4: + raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.") + if y_ref is not None and y_ref.dim() != 4: + raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.") + if x.shape != y_pseudo_gt.shape: + raise ValueError( + f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, " + f"y_pseudo_gt shape {y_pseudo_gt.shape}." + ) + if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: + raise ValueError( + f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, " + f"y_ref shape {y_ref.shape}." + ) + + # compute loss + loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) + + return loss diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index d61ed57f7f..3a6e4aa554 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .conjugate_gradient import ConjugateGradient from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .drop_path import DropPath from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py new file mode 100644 index 0000000000..93a45930d7 --- /dev/null +++ b/monai/networks/layers/conjugate_gradient.py @@ -0,0 +1,112 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable + +import torch +from torch import nn + + +def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensors x1 and x2: sum(x1.*x2) + """ + if torch.is_complex(x1): + assert torch.is_complex(x2), "x1 and x2 must both be complex" + return torch.sum(x1.conj() * x2) + else: + return torch.sum(x1 * x2) + + +def _zdot_single(x: torch.Tensor) -> torch.Tensor: + """ + Complex dot product between tensor x and itself + """ + res = _zdot(x, x) + if torch.is_complex(res): + return res.real + else: + return res + + +class ConjugateGradient(nn.Module): + """ + Congugate Gradient (CG) solver for linear systems Ax = y. + + For linear_op that is positive definite and self-adjoint, CG is + guaranteed to converge CG is often used to solve linear systems of the form + Ax = y, where A is too large to store explicitly, but can be computed via a + linear operator. + + As a result, here we won't set A explicitly as a matrix, but rather as a + linear operator. For example, A could be a FFT/IFFT operation + """ + + def __init__(self, linear_op: Callable, num_iter: int): + """ + Args: + linear_op: Linear operator + num_iter: Number of iterations to run CG + """ + super().__init__() + + self.linear_op = linear_op + self.num_iter = num_iter + + def update( + self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + perform one iteration of the CG method. It takes the current solution x, + the current search direction p, the current residual r, and the old + residual norm rsold as inputs. Then it computes the new solution, search + direction, residual, and residual norm, and returns them. + """ + + dy = self.linear_op(p) + p_dot_dy = _zdot(p, dy) + alpha = rsold / p_dot_dy + x = x + alpha * p + r = r - alpha * dy + rsnew = _zdot_single(r) + beta = rsnew / rsold + rsold = rsnew + p = beta * p + r + return x, p, r, rsold + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + run conjugate gradient for num_iter iterations to solve Ax = y + + Args: + x: tensor (real or complex); Initial guess for linear system Ax = y. + The size of x should be applicable to the linear operator. For + example, if the linear operator is FFT, then x is HCHW; if the + linear operator is a matrix multiplication, then x is a vector + + y: tensor (real or complex); Measurement. Same size as x + + Returns: + x: Solution to Ax = y + """ + # Compute residual + r = y - self.linear_op(x) + rsold = _zdot_single(r) + p = r + + # Update + for _i in range(self.num_iter): + x, p, r, rsold = self.update(x, p, r, rsold) + if rsold < 1e-10: + break + return x diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py new file mode 100644 index 0000000000..239dbe3ecd --- /dev/null +++ b/tests/test_conjugate_gradient.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import ConjugateGradient + + +class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): + """Test ConjugateGradient with real-valued input: when the input is real + value, the output should be the inverse of the matrix.""" + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + # define the measurement + y = torch.tensor([1, 2, 3], dtype=torch.float) + # solve for x + x = cg_solver(torch.zeros(a_dim), y) + x_ref = torch.linalg.solve(a_mat, y) + # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + def test_complex_valued_inverse(self): + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + y = torch.tensor([1, 2, 3], dtype=torch.complex64) + x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(a_mat, y) + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py new file mode 100644 index 0000000000..945da657bf --- /dev/null +++ b/tests/test_sure_loss.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.losses import SURELoss + + +class TestSURELoss(unittest.TestCase): + def test_real_value(self): + """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" + sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + + def operator(x): + return x + + y_pseudo_gt = torch.randn(2, 1, 128, 128) + x = torch.randn(2, 1, 128, 128) + loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_value(self): + """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + + def operator(x): + return x + + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) + y_pseudo_gt = torch.randn(2, 2, 128, 128) + x = torch.randn(2, 2, 128, 128) + loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_general_input(self): + """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + + def operator(x): + return x + + perturb_noise_real = torch.randn(2, 1, 128, 128) + perturb_noise_complex = torch.zeros(2, 2, 128, 128) + perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() + y_pseudo_gt_real = torch.randn(2, 1, 128, 128) + y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) + y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() + x_real = torch.randn(2, 1, 128, 128) + x_complex = torch.zeros(2, 2, 128, 128) + x_complex[:, 0, :, :] = x_real.squeeze() + + sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) + sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) + + loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) + loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + + +if __name__ == "__main__": + unittest.main() From d989c18e9ae3f5d338156ef1c32da4561bf07cbb Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 00:13:12 +0100 Subject: [PATCH 07/55] Renamed function to get_wsi_at_mpp Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 9c6ee9c387..5e3b0e9d36 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -620,7 +620,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - return self.reader.get_img_at_mpp(wsi, mpp, atol, rtol) + return self.reader.get_wsi_at_mpp(wsi, mpp, atol, rtol) def get_power(self, wsi, level: int) -> float: """ @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1048,7 +1048,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1305,7 +1305,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 105f00b7c8c1bd1cbee5f6fdc0841b001fc1e636 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Mar 2024 23:53:13 +0000 Subject: [PATCH 08/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 5e3b0e9d36..1f036e334e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -602,7 +602,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. @@ -829,7 +829,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1088,7 +1088,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') @@ -1326,7 +1326,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1348,7 +1348,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') From 5db27c1e15e5b2f2db8a523572ece33edd61807a Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 02:21:13 +0100 Subject: [PATCH 09/55] Reformatted wsi_reader.py Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 42 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1f036e334e..16a3150c4a 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -785,14 +785,14 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -808,8 +808,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -821,14 +823,16 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -839,15 +843,17 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1075,7 +1081,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1091,7 +1097,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) else: @@ -1110,7 +1116,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1128,7 +1134,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1335,7 +1341,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1351,7 +1357,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) else: @@ -1370,7 +1376,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1390,7 +1396,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From 18e82bd0674221331cfd82ff20785774688ce296 Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 25 Mar 2024 07:26:43 +0000 Subject: [PATCH 10/55] auto updates (#7577) Signed-off-by: monai-bot Signed-off-by: monai-bot Signed-off-by: Nikolas Schmitz --- tests/test_conjugate_gradient.py | 1 + tests/test_sure_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 239dbe3ecd..64efe3b168 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -19,6 +19,7 @@ class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 945da657bf..903f9bd2ca 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -19,6 +19,7 @@ class TestSURELoss(unittest.TestCase): + def test_real_value(self): """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) From 5bb531e8b3ae16316b465162998072014fb50792 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 11:18:03 +0100 Subject: [PATCH 11/55] Fixed return type Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 16a3150c4a..d7cfb444e3 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1054,7 +1054,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 5214c56241509fa447e3cb4a8a59a515e287a8fe Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 12:07:06 +0100 Subject: [PATCH 12/55] Small fixes Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index d7cfb444e3..be121efa40 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> Any: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1311,7 +1311,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 3f055a9022d027386566e2e94420f997360988da Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 25 Mar 2024 22:13:56 -0400 Subject: [PATCH 13/55] Remove nested error propagation on `ConfigComponent` instantiate (#7569) Fixes #7451 ### Description Reduces the length of error messages and error messages being propagated twice. This helps debug better when long `ConfigComponent`s are being instantiated. Refer to issue #7451 for more details ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Suraj Pai Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/bundle/config_item.py | 5 +---- monai/utils/module.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 844d5b30bf..e5122bf3de 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object: mode = self.get_config().get("_mode_", CompInitMode.DEFAULT) args = self.resolve_args() args.update(kwargs) - try: - return instantiate(modname, mode, **args) - except Exception as e: - raise RuntimeError(f"Failed to instantiate {self}") from e + return instantiate(modname, mode, **args) class ConfigExpression(ConfigItem): diff --git a/monai/utils/module.py b/monai/utils/module.py index 5e058c105b..6f301d8067 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -272,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any: return pdb.runcall(component, **kwargs) except Exception as e: raise RuntimeError( - f"Failed to instantiate component '{__path}' with kwargs: {kwargs}" + f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}" f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode." ) from e From 3264079906ffa02055c80eb427f7157fd398b151 Mon Sep 17 00:00:00 2001 From: Juampa <1523654+juampatronics@users.noreply.github.com> Date: Tue, 26 Mar 2024 03:57:36 +0100 Subject: [PATCH 14/55] 2872 implementation of mixup, cutmix and cutout (#7198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #2872 ### Description Implementation of mixup, cutmix and cutout as described in the original papers. Current implementation support both, the dictionary-based batches and tuples of tensors. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: monai-bot Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Ishan Dutta Signed-off-by: dependabot[bot] Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: binliu Signed-off-by: axel.vlaminck Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ishan Dutta Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: axel.vlaminck Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl --- docs/source/transforms.rst | 42 +++++ docs/source/transforms_idx.rst | 10 + monai/transforms/__init__.py | 12 ++ monai/transforms/regularization/__init__.py | 10 + monai/transforms/regularization/array.py | 173 ++++++++++++++++++ monai/transforms/regularization/dictionary.py | 97 ++++++++++ tests/test_regularization.py | 90 +++++++++ 7 files changed, 434 insertions(+) create mode 100644 monai/transforms/regularization/__init__.py create mode 100644 monai/transforms/regularization/array.py create mode 100644 monai/transforms/regularization/dictionary.py create mode 100644 tests/test_regularization.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8990e7991d..bd3feb3497 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -661,6 +661,27 @@ Post-processing :members: :special-members: __call__ +Regularization +^^^^^^^^^^^^^^ + +`CutMix` +"""""""" +.. autoclass:: CutMix + :members: + :special-members: __call__ + +`CutOut` +"""""""" +.. autoclass:: CutOut + :members: + :special-members: __call__ + +`MixUp` +""""""" +.. autoclass:: MixUp + :members: + :special-members: __call__ + Signal ^^^^^^^ @@ -1707,6 +1728,27 @@ Post-processing (Dict) :members: :special-members: __call__ +Regularization (Dict) +^^^^^^^^^^^^^^^^^^^^^ + +`CutMixd` +""""""""" +.. autoclass:: CutMixd + :members: + :special-members: __call__ + +`CutOutd` +""""""""" +.. autoclass:: CutOutd + :members: + :special-members: __call__ + +`MixUpd` +"""""""" +.. autoclass:: MixUpd + :members: + :special-members: __call__ + Signal (Dict) ^^^^^^^^^^^^^ diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst index f4d02a483f..650d45db71 100644 --- a/docs/source/transforms_idx.rst +++ b/docs/source/transforms_idx.rst @@ -74,6 +74,16 @@ Post-processing post.array post.dictionary +Regularization +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: _gen + :nosignatures: + + regularization.array + regularization.dictionary + Signal ^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..349533fb3e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -336,6 +336,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .regularization.array import CutMix, CutOut, MixUp +from .regularization.dictionary import ( + CutMixd, + CutMixD, + CutMixDict, + CutOutd, + CutOutD, + CutOutDict, + MixUpd, + MixUpD, + MixUpDict, +) from .signal.array import ( SignalContinuousWavelet, SignalFillEmpty, diff --git a/monai/transforms/regularization/__init__.py b/monai/transforms/regularization/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/regularization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py new file mode 100644 index 0000000000..6c9022d647 --- /dev/null +++ b/monai/transforms/regularization/array.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from math import ceil, sqrt + +import torch + +from ..transform import RandomizableTransform + +__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] + + +class Mixer(RandomizableTransform): + def __init__(self, batch_size: int, alpha: float = 1.0) -> None: + """ + Mixer is a base class providing the basic logic for the mixup-class of + augmentations. In all cases, we need to sample the mixing weights for each + sample (lambda in the notation used in the papers). Also, pairs of samples + being mixed are picked by randomly shuffling the batch samples. + + Args: + batch_size (int): number of samples per batch. That is, samples are expected tp + be of size batchsize x channels [x depth] x height x width. + alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha) + distribution. Defaults to 1.0, the uniform distribution. + """ + super().__init__() + if alpha <= 0: + raise ValueError(f"Expected positive number, but got {alpha = }") + self.alpha = alpha + self.batch_size = batch_size + + @abstractmethod + def apply(self, data: torch.Tensor): + raise NotImplementedError() + + def randomize(self, data=None) -> None: + """ + Sometimes you need may to apply the same transform to different tensors. + The idea is to get a sample and then apply it with apply() as often + as needed. You need to call this method everytime you apply the transform to a new + batch. + """ + self._params = ( + torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), + self.R.permutation(self.batch_size), + ) + + +class MixUp(Mixer): + """MixUp as described in: + Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. + mixup: Beyond Empirical Risk Minimization, ICLR 2018 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. + """ + + def apply(self, data: torch.Tensor): + weight, perm = self._params + nsamples, *dims = data.shape + if len(weight) != nsamples: + raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") + + if len(dims) not in [3, 4]: + raise ValueError("Unexpected number of dimensions") + + mixweight = weight[(Ellipsis,) + (None,) * len(dims)] + return mixweight * data + (1 - mixweight) * data[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + if labels is None: + return self.apply(data) + return self.apply(data), self.apply(labels) + + +class CutMix(Mixer): + """CutMix augmentation as described in: + Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo. + CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, + ICCV 2019 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles used during for mixing. + Please refer to the paper for details. + + The most common use case is something close to: + + .. code-block:: python + + cm = CutMix(batch_size=8, alpha=0.5) + for batch in loader: + images, labels = batch + augimg, auglabels = cm(images, labels) + output = model(augimg) + loss = loss_function(output, auglabels) + ... + + """ + + def apply(self, data: torch.Tensor): + weights, perm = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + (1 - mask) * data[perm, ...] + + def apply_on_labels(self, labels: torch.Tensor): + weights, perm = self._params + nsamples, *dims = labels.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mixweight = weights[(Ellipsis,) + (None,) * len(dims)] + return mixweight * labels + (1 - mixweight) * labels[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + augmented = self.apply(data) + return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + + +class CutOut(Mixer): + """Cutout as described in the paper: + Terrance DeVries, Graham W. Taylor. + Improved Regularization of Convolutional Neural Networks with Cutout, + arXiv:1708.04552 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles being cut put. + Please refer to the paper for details. + """ + + def apply(self, data: torch.Tensor): + weights, _ = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + + def __call__(self, data: torch.Tensor): + self.randomize() + return self.apply(data) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py new file mode 100644 index 0000000000..373913da99 --- /dev/null +++ b/monai/transforms/regularization/dictionary.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.config import KeysCollection +from monai.utils.misc import ensure_tuple + +from ..transform import MapTransform +from .array import CutMix, CutOut, MixUp + +__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] + + +class MixUpd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.MixUp`. + + Notice that the mixup transformation will be the same for all entries + for consistency, i.e. images and labels must be applied the same augmenation. + """ + + def __init__( + self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixup = MixUp(batch_size, alpha) + + def __call__(self, data): + self.mixup.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixup.apply(data[k]) + return result + + +class CutMixd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutMix`. + + Notice that the mixture weights will be the same for all entries + for consistency, i.e. images and labels must be aggregated with the same weights, + but the random crops are not. + """ + + def __init__( + self, + keys: KeysCollection, + batch_size: int, + label_keys: KeysCollection | None = None, + alpha: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixer = CutMix(batch_size, alpha) + self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + + def __call__(self, data): + self.mixer.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixer.apply(data[k]) + for k in self.label_keys: + result[k] = self.mixer.apply_on_labels(data[k]) + return result + + +class CutOutd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutOut`. + + Notice that the cutout is different for every entry in the dictionary. + """ + + def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.cutout = CutOut(batch_size) + + def __call__(self, data): + result = dict(data) + self.cutout.randomize() + for k in self.keys: + result[k] = self.cutout(data[k]) + return result + + +MixUpD = MixUpDict = MixUpd +CutMixD = CutMixDict = CutMixd +CutOutD = CutOutDict = CutOutd diff --git a/tests/test_regularization.py b/tests/test_regularization.py new file mode 100644 index 0000000000..d381ea72ca --- /dev/null +++ b/tests/test_regularization.py @@ -0,0 +1,90 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd + + +class TestMixup(unittest.TestCase): + def test_mixup(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup = MixUp(6, 1.0) + output = mixup(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + + with self.assertRaises(ValueError): + MixUp(6, -0.5) + + mixup = MixUp(6, 0.5) + for dims in [2, 3]: + with self.assertRaises(ValueError): + shape = (5, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup(sample) + + def test_mixupd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + mixup = MixUpd(["a", "b"], 6) + output = mixup(sample) + self.assertTrue(torch.allclose(output["a"], output["b"])) + + with self.assertRaises(ValueError): + MixUpd(["k1", "k2"], 6, -0.5) + + +class TestCutMix(unittest.TestCase): + def test_cutmix(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutmix = CutMix(6, 1.0) + output = cutmix(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + + def test_cutmixd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + label = torch.randint(0, 1, shape) + sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} + cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + output = cutmix(sample) + # croppings are different on each application + self.assertTrue(not torch.allclose(output["a"], output["b"])) + # but mixing of labels is not affected by it + self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) + + +class TestCutOut(unittest.TestCase): + def test_cutout(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutout = CutOut(6, 1.0) + output = cutout(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + +if __name__ == "__main__": + unittest.main() From 6fcc4a6995a012fca2d6a7928ec0ff64ce9672c3 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:29:47 +0200 Subject: [PATCH 15/55] Updated function get_wsi_at_mpp; added function _resize_to_mpp_res to reduce redundancy; for get_mpp of TiffFileWSIReader: added check to prevent division by zero error. --- monai/data/wsi_reader.py | 236 ++++++++++++++++++++------------------- 1 file changed, 122 insertions(+), 114 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index be121efa40..7b8e3b12d1 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -781,7 +781,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp @@ -789,11 +789,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + # mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -808,9 +807,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) else: @@ -820,40 +818,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -941,6 +911,36 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @require_pkg(pkg_name="openslide") @@ -1072,17 +1072,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1097,8 +1092,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1107,34 +1103,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1208,6 +1182,34 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @require_pkg(pkg_name="tifffile") @@ -1295,21 +1297,27 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Needs to be improved - unit = str(unit.value)[8:] - # unit = str(unit.value.name).lower() # TODO: Merge both methods + if unit is not None: # Test with more tiff files + if isinstance(unit.value, int): + unit = str(unit.value.name).lower() + else: + unit = str(unit.value)[8:] else: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") unit = "micrometer" convert_to_micron = ConvertUnits(unit, "micrometer") - # Here x and y resolutions are rational numbers so each of them is represented by a tuple. + + # Here, x and y resolutions are rational numbers so each of them is represented by a tuple. yres = wsi.pages[level].tags["YResolution"].value xres = wsi.pages[level].tags["XResolution"].value - return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) - - raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + if xres[0] & yres[0]: + return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) + else: + raise ValueError("The `XResolution` and/or `YResolution` property of the image is zero, " + "which is needed to obtain `mpp` for this file. Please use `level` instead.") + raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ @@ -1331,18 +1339,15 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] + # closest_lvl_dim = lvl_dims[closest_lvl] + # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1357,8 +1362,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + print('Tifffile, within tolerance') + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1367,36 +1372,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: - # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1445,7 +1425,7 @@ def _get_patch( Extracts and returns a patch image form the whole slide image. Args: - wsi: a whole slide image object loaded from a file or a lis of such objects + wsi: a whole slide image object loaded from a file or a list of such objects location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). If None, it is set to the full image size at the given level. @@ -1477,3 +1457,31 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = self.get_size(wsi, closest_lvl) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From 4b0c9baf6a31b4fac7056e79381dc62a85809fef Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:54:29 +0200 Subject: [PATCH 16/55] Minor fixes: removed unnecessary comments --- monai/data/wsi_reader.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7b8e3b12d1..7df8256ad6 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1337,16 +1337,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - # closest_lvl_dim = lvl_dims[closest_lvl] - # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level @@ -1362,7 +1357,6 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print('Tifffile, within tolerance') closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: From 66508e92506b315fe745a4664d6cbe5a6763d2cc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 22:20:38 +0000 Subject: [PATCH 17/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 53048c2d70..1ba799f095 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -911,17 +911,17 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") @@ -1182,17 +1182,17 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") @@ -1447,11 +1447,11 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. From 441b4629f8905dba1fecbc16bb303c0c86f7ff17 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 4 Aug 2024 17:44:48 +0200 Subject: [PATCH 18/55] Added function _compute_mpp_target_res to BaseWSIReader --- monai/data/wsi_reader.py | 76 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7df8256ad6..c6e2b67914 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,28 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -911,7 +933,7 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). @@ -926,20 +948,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @@ -1182,7 +1197,7 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). @@ -1196,19 +1211,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @@ -1297,13 +1306,10 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and wsi.pages[level].tags["YResolution"].value ): unit = wsi.pages[level].tags.get("ResolutionUnit") - if unit is not None: # Test with more tiff files - if isinstance(unit.value, int): - unit = str(unit.value.name).lower() - else: - unit = str(unit.value)[8:] - else: + if unit is not None: + unit = str(unit.value.name) + if unit is None or len(unit) == 0: warnings.warn("The resolution unit is missing. `micrometer` will be used as default.") unit = "micrometer" @@ -1451,7 +1457,7 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). @@ -1461,21 +1467,15 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = self.get_size(wsi, closest_lvl) - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From d73d739de08101dc3781f3cf06cd862a8a775e16 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 4 Aug 2024 17:53:19 +0200 Subject: [PATCH 19/55] Added new feature and merged updates from main repository --- monai/data/wsi_reader.py | 61 +++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1ba799f095..e217f41c7e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,28 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -926,20 +948,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @@ -1196,19 +1211,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @@ -1457,21 +1466,15 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = self.get_size(wsi, closest_lvl) - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From 5461801547e48d941d20c876729da15769868eb1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:01 +0000 Subject: [PATCH 20/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index e217f41c7e..81afafb246 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -434,13 +434,13 @@ def get_data( def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -1466,7 +1466,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") From 59683bc07cef236bbe4a61129fa6f963533f217c Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 11 Aug 2024 22:45:54 +0200 Subject: [PATCH 21/55] Added a function _compute_mpp_tolerances which checks the mpp tolerances to BaseWSIReader; Edited docstrings --- monai/data/wsi_reader.py | 163 +++++++++++++++++---------------------- 1 file changed, 70 insertions(+), 93 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 81afafb246..7b2e2eb0db 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,27 +431,62 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata - def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: tuple): """ - Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + Computes the target dimensions for resizing a whole slide image + to match a user-specified resolution in microns per pixel (MPP). Args: - wsi: whole slide image object from WSIReader - user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. - closest_lvl: the wsi level that is closest to the user-provided mpp resolution. - mpp_list: list of mpp values for all levels of a whole slide image. + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + closest_lvl_dim: Dimensions (height, width) of the image at the closest level. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + ds_factor_x = mpp_closest_lvl_x / mpp[0] + ds_factor_y = mpp_closest_lvl_y / mpp[1] target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y + + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: + """ + Determines if user-provided MPP values are within a specified tolerance of the closest + level's MPP and checks if the closest level has higher resolution than desired MPP. + + Args: + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. + atol: Absolute tolerance for MPP comparison. + rtol: Relative tolerance for MPP comparison. + + """ + user_mpp_x, user_mpp_y = mpp + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + is_within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + is_within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + is_within_tolerance = is_within_tolerance_x & is_within_tolerance_y + + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + return is_within_tolerance, closest_level_is_bigger def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -802,50 +837,27 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - - # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - - # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - - # mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1087,43 +1099,25 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1342,40 +1336,23 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From 547442ed138734a29c472eb97613b8e6ad2a8e4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Aug 2024 20:46:19 +0000 Subject: [PATCH 22/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7b2e2eb0db..57df016140 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -453,7 +453,7 @@ def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: t target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y - + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: """ Determines if user-provided MPP values are within a specified tolerance of the closest From 54774d903b45babc05adadf12e8a99d4074af048 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Tue, 4 Mar 2025 11:12:27 +0100 Subject: [PATCH 23/55] Updated WSI reader --- monai/data/wsi_reader.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 57df016140..8b06a0dd06 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -842,7 +842,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -962,7 +962,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) @@ -1067,7 +1067,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and "tiff.YResolution" in wsi.properties and wsi.properties["tiff.YResolution"] and wsi.properties["tiff.XResolution"] - ): + ): unit = wsi.properties.get("tiff.ResolutionUnit") if unit is None: warnings.warn("The resolution unit is missing, `micrometer` will be used as default.") @@ -1102,7 +1102,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -1207,7 +1207,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = wsi.level_dimensions[closest_lvl] - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) @@ -1339,7 +1339,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -1449,7 +1449,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = self.get_size(wsi, closest_lvl) - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) From 3f87b1055beb23c5853bb09fddbcc313f9164d6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Mar 2025 10:19:38 +0000 Subject: [PATCH 24/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8b06a0dd06..70f32110ff 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1067,7 +1067,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and "tiff.YResolution" in wsi.properties and wsi.properties["tiff.YResolution"] and wsi.properties["tiff.XResolution"] - ): + ): unit = wsi.properties.get("tiff.ResolutionUnit") if unit is None: warnings.warn("The resolution unit is missing, `micrometer` will be used as default.") From f70d0ef6205a61fcc1164aa1e3b6e0046de2da26 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Thu, 17 Jul 2025 02:01:34 +0200 Subject: [PATCH 25/55] Added get_wsi_at_mpp tests; fixed a few bugs --- monai/data/wsi_reader.py | 9 ++- tests/utils/enums/test_wsireader.py | 103 ++++++++++++++++++++++++---- 2 files changed, 98 insertions(+), 14 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 70f32110ff..2962108414 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -965,7 +965,12 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + closest_lvl_wsi = cucim_resize( + wsi_arr, + (target_res_x, target_res_y), + order=1, + preserve_range=True, + anti_aliasing=False).astype(cp.uint8) return closest_lvl_wsi @@ -1210,7 +1215,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_y, target_res_x), pil_image.BILINEAR) # row, col order return closest_lvl_wsi diff --git a/tests/utils/enums/test_wsireader.py b/tests/utils/enums/test_wsireader.py index 3b84af7345..df06bf3ec7 100644 --- a/tests/utils/enums/test_wsireader.py +++ b/tests/utils/enums/test_wsireader.py @@ -37,9 +37,12 @@ has_tiff = has_tiff and has_codec TESTS_PATH = Path(__file__).parents[2] -WSI_GENERIC_TIFF_KEY = "wsi_generic_tiff" +WSI_GENERIC_TIFF_KEY = "wsi_generic_tiff" # TIFF image with incorrect mpp values WSI_GENERIC_TIFF_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_GENERIC_TIFF_KEY}.tiff") +WSI_GENERIC_TIFF_CORRECT_MPP_KEY = "wsi_generic_tiff_corrected" +WSI_GENERIC_TIFF_CORRECT_MPP_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_GENERIC_TIFF_CORRECT_MPP_KEY}.tiff") + WSI_APERIO_SVS_KEY = "wsi_aperio_svs" WSI_APERIO_SVS_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_APERIO_SVS_KEY}.svs") @@ -256,6 +259,54 @@ "cpu", ] +TEST_CASE_SVS_MPP_1 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (4.0, 4.0), "atol": 0.0, "rtol": 0.1}, + {"openslide": (4106, 5739, 4), "cucim": (4106, 5739, 3)}, +] + +TEST_CASE_SVS_MPP_2 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (8.0, 8.0)}, + {"openslide": (2057, 2875, 4), "cucim": (2057, 2875, 3)}, +] + +TEST_CASE_SVS_MPP_3 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (3.0, 3.0)}, + {"openslide": (5475, 7652, 4), "cucim": (5475, 7652, 3)}, +] + +TEST_CASE_SVS_MPP_4 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (1.5, 1.5)}, + {"openslide": (10949, 15303, 4), "cucim": (10949, 15303, 3)}, +] + +TEST_CASE_TIFF_MPP_1 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (4.0, 4.0), "atol": 0.0, "rtol": 0.1}, + {"openslide": (4114, 5750, 4), "cucim": (4114, 5750, 3), "tifffile": (4106, 5739, 3)}, +] + +TEST_CASE_TIFF_MPP_2 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (8.0, 8.0)}, + {"openslide": (2057, 2875, 4), "cucim": (2057, 2875, 3), "tifffile": (2053, 2869, 3)}, +] + +TEST_CASE_TIFF_MPP_3 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (3.0, 3.0)}, + {"openslide": (5475, 7652, 4), "cucim": (5475, 7652, 3), "tifffile": (5475, 7651, 3)}, +] + +TEST_CASE_TIFF_MPP_4 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (1.5, 1.5)}, + {"openslide": (10949, 15303, 4), "cucim": (10949, 15303, 3), "tifffile": (10949, 15303, 3)}, +] + TEST_CASE_DEVICE_2 = [ WSI_GENERIC_TIFF_PATH, {"level": 8, "dtype": torch.float32, "device": "cuda"}, @@ -407,17 +458,45 @@ class WSIReaderTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_WHOLE_0]) - def test_read_whole_image(self, file_path, level, expected_shape): - reader = WSIReader(self.backend, level=level) - with reader.read(file_path) as img_obj: - img, meta = reader.get_data(img_obj) - self.assertTupleEqual(img.shape, expected_shape) - self.assertEqual(meta["backend"], self.backend) - self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) - self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) - assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + # @parameterized.expand([TEST_CASE_WHOLE_0]) + # def test_read_whole_image(self, file_path, level, expected_shape): + # reader = WSIReader(self.backend, level=level) + # with reader.read(file_path) as img_obj: + # img, meta = reader.get_data(img_obj) + # self.assertTupleEqual(img.shape, expected_shape) + # self.assertEqual(meta["backend"], self.backend) + # self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) + # self.assertEqual(meta[WSIPatchKeys.LEVEL], level) + # assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) + # assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + + @parameterized.expand( + [ + TEST_CASE_SVS_MPP_1, + TEST_CASE_SVS_MPP_2, + TEST_CASE_SVS_MPP_3, + TEST_CASE_SVS_MPP_4, + TEST_CASE_TIFF_MPP_1, + TEST_CASE_TIFF_MPP_2, + TEST_CASE_TIFF_MPP_3, + TEST_CASE_TIFF_MPP_4 + ] + ) + def test_get_wsi_at_mpp(self, file_path, func_kwargs, expected_shape): + # Tifffile backend cannot read MPP from the SVS file, so skip. + if self.backend == "tifffile" and file_path == WSI_APERIO_SVS_PATH: + self.skipTest("TiffFileWSIReader cannot extract MPP from SVS files.") + + # Look up the expected shape for the current backend + if self.backend not in expected_shape: + self.skipTest(f"No expected shape defined for backend '{self.backend}' in this test case.") + expected_shape = expected_shape[self.backend] + + reader = WSIReader(self.backend) + with reader.read(file_path) as wsi: + wsi_arr = reader.get_wsi_at_mpp(wsi, **func_kwargs) + + self.assertTupleEqual(wsi_arr.shape, expected_shape) @parameterized.expand( [ From 61fc9bf9a5e4951708b6e98ed69d296affb9723a Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 22 Mar 2024 14:07:47 +0100 Subject: [PATCH 26/55] Added function get_img_at_mpp to class OpenSlideWSIReader of module wsi_reader.py --- monai/data/wsi_reader.py | 82 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 2a4fe9f7a8..587b0336b3 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,6 +19,7 @@ import numpy as np import torch +import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -940,6 +941,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + mpp_closest_lvl = mpp_list[closest_lvl] + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + return closest_lvl_wsi + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + return closest_lvl_wsi + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + return closest_lvl_wsi + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From a834514f4b8464040437543c3e424e7487be92e7 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 22 Mar 2024 16:42:18 +0100 Subject: [PATCH 27/55] Added get_img_at_mpp to class CuCIMWSIReader --- monai/data/wsi_reader.py | 98 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 587b0336b3..6b8597b72f 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,6 +603,23 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) + + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + return self.reader.get_img_at_mpp(wsi, mpp, atol, rtol) def get_power(self, wsi, level: int) -> float: """ @@ -745,6 +762,87 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + mpp_closest_lvl = mpp_list[closest_lvl] + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + + return closest_lvl_wsi + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + return closest_lvl_wsi + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + return closest_lvl_wsi + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From c92475d66fca9d584cecce47885d271f8ade2184 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 19:18:42 +0100 Subject: [PATCH 28/55] Added function get_img_at_mpp to class TifffileWSIReader; changed resizing function to Image.resize, cucim.skimage.transform.resize --- monai/data/wsi_reader.py | 154 ++++++++++++++++++++++++++++++++------- 1 file changed, 126 insertions(+), 28 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 6b8597b72f..b3a75783f6 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -19,7 +19,6 @@ import numpy as np import torch -import cv2 from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data.image_reader import ImageReader, _stack_images @@ -778,9 +777,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] @@ -797,13 +801,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -814,15 +817,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -833,15 +837,18 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_y = mpp_closest_lvl_y / user_mpp_y closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') - return closest_lvl_wsi + + wsi_arr = cp.asnumpy(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1055,9 +1062,12 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ + pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # Should not throw ValueError, instead just return the closest value; + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] @@ -1078,9 +1088,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - return closest_lvl_wsi else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x @@ -1091,15 +1100,14 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) - + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - return closest_lvl_wsi + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1110,15 +1118,16 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_y = mpp_closest_lvl_y / user_mpp_y closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = np.array(wsi.read_region(location=(0, 0), level=closest_lvl, size=closest_lvl_dim))[:, :, :3] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = cv2.resize(closest_lvl_wsi, dsize=(target_res_x, target_res_y), interpolation=cv2.INTER_LINEAR) - + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') - return closest_lvl_wsi + + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr def get_power(self, wsi, level: int) -> float: """ @@ -1290,6 +1299,95 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + """ + Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. + The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. + If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. + + Args: + wsi: whole slide image object from WSIReader + mpp: the resolution in micron per pixel at which the representation of the whole slide image should be extracted. + atol: the acceptable absolute tolerance for resolution in micro per pixel. + rtol: the acceptable relative tolerance for resolution in micro per pixel. + + """ + + pil_image, _ = optional_import("PIL", name="Image") + user_mpp_x, user_mpp_y = mpp + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1? + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? + + mpp_closest_lvl = mpp_list[closest_lvl] + + lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width) + closest_lvl_dim = lvl_dims[closest_lvl] + closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + + print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + within_tolerance = within_tolerance_x & within_tolerance_y + + if within_tolerance: + # Take closest_level and continue with returning img at level + print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + + else: + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + if closest_level_is_bigger: + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + + else: + # Else: increase resolution (ie, decrement level) and then downsample + closest_lvl = closest_lvl - 1 + mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp_x + ds_factor_y = mpp_closest_lvl_y / user_mpp_y + + closest_lvl_dim = lvl_dims[closest_lvl] + closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + + wsi_arr = np.array(closest_lvl_wsi) + return wsi_arr + def get_power(self, wsi, level: int) -> float: """ Returns the objective power of the whole slide image at a given level. From 5b988a2669d8ccc1a1847630e605a9d983cc7e8f Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:17:37 +0100 Subject: [PATCH 29/55] Small changes --- monai/data/wsi_reader.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index b3a75783f6..8a97da47ba 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -607,8 +607,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -765,8 +767,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -786,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -805,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -823,8 +827,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: @@ -843,7 +846,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - # closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), Image.BILINEAR) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') @@ -1050,8 +1052,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -1123,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1303,7 +1307,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level—indicating that the closest level has a higher resolution than requested—the image is downscaled to a resolution that matches the user-provided mpp. + If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. @@ -1317,8 +1321,8 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # QuPath show 4 levels in the pyramid, but len(wsi.pages) is 1? - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1356,7 +1360,6 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) @@ -1376,7 +1379,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - # closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) From 5168a12efaa8b4c3e13ba83b78c9af8c51ea0774 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 24 Mar 2024 23:21:57 +0100 Subject: [PATCH 30/55] Small changes --- monai/data/wsi_reader.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8a97da47ba..04ee7cec32 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -790,7 +790,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] # x,y notation + closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -809,7 +809,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) # size in x,y notation + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -827,13 +827,13 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) # output_shape in row, col notation + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1115,7 +1115,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP + mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl ds_factor_x = mpp_closest_lvl_x / user_mpp_x @@ -1127,7 +1127,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) # Output size in x,y notation + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') wsi_arr = np.array(closest_lvl_wsi) @@ -1307,8 +1307,10 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. - If the user-provided mpp is larger than the mpp of the closest level the image is downscaled to a resolution that matches the user-provided mpp. - Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, the next lower level (which has a higher resolution) is chosen. + If the user-provided mpp is larger than the mpp of the closest level, + the image is downscaled to a resolution that matches the user-provided mpp. + Otherwise, if the closest level's resolution is not sufficient to meet the user's requested resolution, + the next lower level (which has a higher resolution) is chosen. The image from this level is then down-scaled to achieve a resolution at the user-provided mpp value. Args: @@ -1327,7 +1329,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] # Returns size in (height, width) + lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) From 143747788775ef4308af54f78ee11f0c82ed7689 Mon Sep 17 00:00:00 2001 From: cxlcl Date: Fri, 22 Mar 2024 09:54:40 -0700 Subject: [PATCH 31/55] Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (#7308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on the discussion topic [here](https://github.com/Project-MONAI/MONAI/discussions/7161#discussion-5773293), we implemented the Conjugate-Gradient algorithm for linear operator inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for ground-truth-date free diffusion process guidance that is proposed in [2] and illustrated in the algorithm below: Screenshot 2023-12-10 at 10 19 25 PM The Conjugate-Gradient (CG) algorithm is used to solve for the inversion of the linear operator in Line-4 in the algorithm above, where the linear operator is too large to store explicitly as a matrix (such as FFT/IFFT of an image) and invert directly. Instead, we can solve for the linear inversion iteratively as in CG. The SURE loss is applied for Line-6 above. This is a differentiable loss function that can be used to train/giude an operator (e.g. neural network), where the pseudo ground truth is available but the reference ground truth is not. For example, in the MRI reconstruction, the pseudo ground truth is the zero-filled reconstruction and the reference ground truth is the fully sampled reconstruction. The reference ground truth is not available due to the lack of fully sampled. **Reference** [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics 1981 [[paper link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)] [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. MICCAI 2023 [[paper link](https://arxiv.org/pdf/2310.01799.pdf)] - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Nikolas Schmitz --- tests/test_conjugate_gradient.py | 55 +++++++++++++++++++++++++ tests/test_sure_loss.py | 71 ++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 tests/test_conjugate_gradient.py create mode 100644 tests/test_sure_loss.py diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py new file mode 100644 index 0000000000..239dbe3ecd --- /dev/null +++ b/tests/test_conjugate_gradient.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.layers import ConjugateGradient + + +class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): + """Test ConjugateGradient with real-valued input: when the input is real + value, the output should be the inverse of the matrix.""" + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + # define the measurement + y = torch.tensor([1, 2, 3], dtype=torch.float) + # solve for x + x = cg_solver(torch.zeros(a_dim), y) + x_ref = torch.linalg.solve(a_mat, y) + # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + def test_complex_valued_inverse(self): + a_dim = 3 + a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) + + def a_op(x): + return a_mat @ x + + cg_solver = ConjugateGradient(a_op, num_iter=100) + y = torch.tensor([1, 2, 3], dtype=torch.complex64) + x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) + x_ref = torch.linalg.solve(a_mat, y) + self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py new file mode 100644 index 0000000000..945da657bf --- /dev/null +++ b/tests/test_sure_loss.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.losses import SURELoss + + +class TestSURELoss(unittest.TestCase): + def test_real_value(self): + """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" + sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) + + def operator(x): + return x + + y_pseudo_gt = torch.randn(2, 1, 128, 128) + x = torch.randn(2, 1, 128, 128) + loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_value(self): + """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" + + def operator(x): + return x + + sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) + y_pseudo_gt = torch.randn(2, 2, 128, 128) + x = torch.randn(2, 2, 128, 128) + loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) + self.assertAlmostEqual(loss.item(), 0.0) + + def test_complex_general_input(self): + """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" + + def operator(x): + return x + + perturb_noise_real = torch.randn(2, 1, 128, 128) + perturb_noise_complex = torch.zeros(2, 2, 128, 128) + perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() + y_pseudo_gt_real = torch.randn(2, 1, 128, 128) + y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) + y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() + x_real = torch.randn(2, 1, 128, 128) + x_complex = torch.zeros(2, 2, 128, 128) + x_complex[:, 0, :, :] = x_real.squeeze() + + sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) + sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) + + loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) + loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + + +if __name__ == "__main__": + unittest.main() From a4a78e3a5209701a05cb8bdad36650652b14d636 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 00:13:12 +0100 Subject: [PATCH 32/55] Renamed function to get_wsi_at_mpp Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 04ee7cec32..0e116b43ef 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -620,7 +620,7 @@ def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - return self.reader.get_img_at_mpp(wsi, mpp, atol, rtol) + return self.reader.get_wsi_at_mpp(wsi, mpp, atol, rtol) def get_power(self, wsi, level: int) -> float: """ @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1048,7 +1048,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1303,7 +1303,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_img_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From b1ed4ff1e211de7d5e4d82bbf8b08d1dc8c6ce53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Mar 2024 23:53:13 +0000 Subject: [PATCH 33/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 0e116b43ef..aca2ecd8c6 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -602,7 +602,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. @@ -829,7 +829,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') - + else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 @@ -1088,7 +1088,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') @@ -1324,7 +1324,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) + closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] @@ -1346,7 +1346,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) within_tolerance = within_tolerance_x & within_tolerance_y - + if within_tolerance: # Take closest_level and continue with returning img at level print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') From 6b18e9b3c471213e6db863766a24f3ded32d8c47 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 02:21:13 +0100 Subject: [PATCH 34/55] Reformatted wsi_reader.py Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 42 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index aca2ecd8c6..af2467e67c 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -785,14 +785,14 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions['level_count'])] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -808,8 +808,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -821,14 +823,16 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -839,15 +843,17 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 ds_factor_x = mpp_closest_lvl_x / user_mpp_x ds_factor_y = mpp_closest_lvl_y / user_mpp_y - closest_lvl_dim = wsi.resolutions['level_dimensions'][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + ) wsi_arr = cp.array(closest_lvl_wsi) target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1075,7 +1081,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_closest_lvl = mpp_list[closest_lvl] closest_lvl_dim = wsi.level_dimensions[closest_lvl] - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1091,7 +1097,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) else: @@ -1110,7 +1116,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1128,7 +1134,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1333,7 +1339,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl_dim = lvl_dims[closest_lvl] closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f'Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}') + print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl # Define tolerance intervals for x and y of closest level @@ -1349,7 +1355,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f'User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.') + print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) else: @@ -1368,7 +1374,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}') + print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") else: # Else: increase resolution (ie, decrement level) and then downsample @@ -1388,7 +1394,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f'Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}') + print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From d74d4befd7a04cb536797b442aa8b058dc2935d3 Mon Sep 17 00:00:00 2001 From: monai-bot <64792179+monai-bot@users.noreply.github.com> Date: Mon, 25 Mar 2024 07:26:43 +0000 Subject: [PATCH 35/55] auto updates (#7577) Signed-off-by: monai-bot Signed-off-by: monai-bot Signed-off-by: Nikolas Schmitz --- tests/test_conjugate_gradient.py | 1 + tests/test_sure_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py index 239dbe3ecd..64efe3b168 100644 --- a/tests/test_conjugate_gradient.py +++ b/tests/test_conjugate_gradient.py @@ -19,6 +19,7 @@ class TestConjugateGradient(unittest.TestCase): + def test_real_valued_inverse(self): """Test ConjugateGradient with real-valued input: when the input is real value, the output should be the inverse of the matrix.""" diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 945da657bf..903f9bd2ca 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -19,6 +19,7 @@ class TestSURELoss(unittest.TestCase): + def test_real_value(self): """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) From d6d0cc24f0f2b16112f4478a08b46477d6f9a336 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 11:18:03 +0100 Subject: [PATCH 36/55] Fixed return type Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index af2467e67c..2045fce961 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -603,7 +603,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: """ return self.reader.get_mpp(wsi, level) - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1054,7 +1054,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From 2f4388eb0ec1b86d833afe204dd48681232730d3 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Mon, 25 Mar 2024 12:07:06 +0100 Subject: [PATCH 37/55] Small fixes Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 2045fce961..50f404d5f5 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -763,7 +763,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> Any: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. @@ -1309,7 +1309,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") - def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.array: + def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ Returns the representation of the whole slide image at a given micro-per-pixel (mpp) resolution. The optional tolerance parameters are considered at the level whose mpp value is closest to the one provided by the user. From d0a4881b193ad26c5794b9b1d57b8a924a87b577 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:29:47 +0200 Subject: [PATCH 38/55] Updated function get_wsi_at_mpp; added function _resize_to_mpp_res to reduce redundancy; for get_mpp of TiffFileWSIReader: added check to prevent division by zero error. --- monai/data/wsi_reader.py | 228 ++++++++++++++++++++------------------- 1 file changed, 117 insertions(+), 111 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 50f404d5f5..8f25b5c359 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -781,7 +781,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") user_mpp_x, user_mpp_y = mpp @@ -789,11 +789,10 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + # mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -808,9 +807,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers + (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) else: @@ -820,40 +818,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region( - (0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers - ) - wsi_arr = cp.array(closest_lvl_wsi) - - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) - - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=0) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -941,6 +911,36 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") + cp, _ = optional_import("cupy") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) + wsi_arr = cp.array(closest_lvl_wsi) + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @require_pkg(pkg_name="openslide") @@ -1072,17 +1072,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1097,8 +1092,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + closest_lvl_wsi = wsi.read_region( + (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] + ) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1107,34 +1103,12 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1208,6 +1182,34 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = wsi.level_dimensions[closest_lvl] + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @require_pkg(pkg_name="tifffile") @@ -1302,12 +1304,16 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: unit = "micrometer" convert_to_micron = ConvertUnits(unit, "micrometer") - # Here x and y resolutions are rational numbers so each of them is represented by a tuple. + + # Here, x and y resolutions are rational numbers so each of them is represented by a tuple. yres = wsi.pages[level].tags["YResolution"].value xres = wsi.pages[level].tags["XResolution"].value - return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) - - raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") + if xres[0] & yres[0]: + return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) + else: + raise ValueError("The `XResolution` and/or `YResolution` property of the image is zero, " + "which is needed to obtain `mpp` for this file. Please use `level` instead.") + raise ValueError("`mpp` cannot be obtained for this file. Please use `level` instead.") def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05) -> np.ndarray: """ @@ -1329,18 +1335,15 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl = mpp_list[closest_lvl] - - lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) + # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] + # closest_lvl_dim = lvl_dims[closest_lvl] + # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - print(f"Closest Level: {closest_lvl} with MPP: {mpp_closest_lvl}") - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol @@ -1355,8 +1358,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print(f"User-provided MPP lies within tolerance of level {closest_lvl}, returning wsi at this level.") - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) + print('Tifffile, within tolerance') + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp @@ -1365,36 +1368,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y if closest_level_is_bigger: - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 1: Downscaling using factor {(ds_factor_x, ds_factor_y)}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) else: - # Else: increase resolution (ie, decrement level) and then downsample closest_lvl = closest_lvl - 1 - mpp_closest_lvl = mpp_list[closest_lvl] # Update MPP - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp_x - ds_factor_y = mpp_closest_lvl_y / user_mpp_y - - closest_lvl_dim = lvl_dims[closest_lvl] - closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) # Might be suboptimal - - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) - - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - print(f"Case 2: Downscaling using factor {(ds_factor_x, ds_factor_y)}, now from level {closest_lvl}") + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1443,7 +1421,7 @@ def _get_patch( Extracts and returns a patch image form the whole slide image. Args: - wsi: a whole slide image object loaded from a file or a lis of such objects + wsi: a whole slide image object loaded from a file or a list of such objects location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0). size: (height, width) tuple giving the patch size at the given level (`level`). If None, it is set to the full image size at the given level. @@ -1475,3 +1453,31 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch + + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + pil_image, _ = optional_import("PIL", name="Image") + + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + closest_lvl_dim = self.get_size(wsi, closest_lvl) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) + + target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From 7545148a4bb937855e6e84841899c33aa33699f9 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 31 Jul 2024 19:54:29 +0200 Subject: [PATCH 39/55] Minor fixes: removed unnecessary comments --- monai/data/wsi_reader.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8f25b5c359..53048c2d70 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1333,16 +1333,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - pil_image, _ = optional_import("PIL", name="Image") user_mpp_x, user_mpp_y = mpp - mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles + mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - # lvl_dims = [self.get_size(wsi, lvl) for lvl in range(len(wsi.pages))] - # closest_lvl_dim = lvl_dims[closest_lvl] - # closest_lvl_dim = (closest_lvl_dim[1], closest_lvl_dim[0]) - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] # Define tolerance intervals for x and y of closest level @@ -1358,7 +1353,6 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 if within_tolerance: # Take closest_level and continue with returning img at level - print('Tifffile, within tolerance') closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) else: From b01bf632dcc3c3d45c9e8550cddc636d706c0a30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 22:20:38 +0000 Subject: [PATCH 40/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 53048c2d70..1ba799f095 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -911,17 +911,17 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") @@ -1182,17 +1182,17 @@ def _get_patch( patch = np.moveaxis(patch, -1, self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") @@ -1447,11 +1447,11 @@ def _get_patch( patch = np.take(patch, [0, 1, 2], self.channel_dim) return patch - + def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. From 5c7822f8f706e54b4bc803d61f08a631e1054af7 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 4 Aug 2024 17:53:19 +0200 Subject: [PATCH 41/55] Added new feature and merged updates from main repository --- monai/data/wsi_reader.py | 61 +++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1ba799f095..e217f41c7e 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,6 +431,28 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + """ + Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + + Args: + wsi: whole slide image object from WSIReader + user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. + closest_lvl: the wsi level that is closest to the user-provided mpp resolution. + mpp_list: list of mpp values for all levels of a whole slide image. + + """ + mpp_closest_lvl = mpp_list[closest_lvl] + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl + + ds_factor_x = mpp_closest_lvl_x / user_mpp[0] + ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + + target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) + target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + + return target_res_x, target_res_y + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ Verify whether the specified file or files format is supported by WSI reader. @@ -926,20 +948,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers) - wsi_arr = cp.array(closest_lvl_wsi) - target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + return closest_lvl_wsi @@ -1196,19 +1211,13 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = wsi.level_dimensions[closest_lvl] - closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi @@ -1457,21 +1466,15 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") - mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] - closest_lvl_dim = self.get_size(wsi, closest_lvl) - closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) - target_res_x = int(np.round(closest_lvl_dim[0] * ds_factor_x)) - target_res_y = int(np.round(closest_lvl_dim[1] * ds_factor_y)) + target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + return closest_lvl_wsi From fd0d0cf24df3181410d6e3ee923cf01ea0da0516 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:01 +0000 Subject: [PATCH 42/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index e217f41c7e..81afafb246 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -434,13 +434,13 @@ def get_data( def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): """ Resizes the whole slide image to the specified resolution in microns per pixel (mpp). - + Args: wsi: whole slide image object from WSIReader user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -1466,7 +1466,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. - + """ pil_image, _ = optional_import("PIL", name="Image") From 234f23f6495ad5864e6e74d6251eba6ebcd2a317 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Sun, 11 Aug 2024 22:45:54 +0200 Subject: [PATCH 43/55] Added a function _compute_mpp_tolerances which checks the mpp tolerances to BaseWSIReader; Edited docstrings --- monai/data/wsi_reader.py | 163 +++++++++++++++++---------------------- 1 file changed, 70 insertions(+), 93 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 81afafb246..7b2e2eb0db 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -431,27 +431,62 @@ def get_data( metadata[key] = [m[key] for m in metadata_list] return _stack_images(patch_list, metadata), metadata - def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, user_mpp: tuple): + def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: tuple): """ - Resizes the whole slide image to the specified resolution in microns per pixel (mpp). + Computes the target dimensions for resizing a whole slide image + to match a user-specified resolution in microns per pixel (MPP). Args: - wsi: whole slide image object from WSIReader - user_mpp: the resolution in microns per pixel at which the whole slide image representation should be extracted. - closest_lvl: the wsi level that is closest to the user-provided mpp resolution. - mpp_list: list of mpp values for all levels of a whole slide image. + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + closest_lvl_dim: Dimensions (height, width) of the image at the closest level. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl - ds_factor_x = mpp_closest_lvl_x / user_mpp[0] - ds_factor_y = mpp_closest_lvl_y / user_mpp[1] + ds_factor_x = mpp_closest_lvl_x / mpp[0] + ds_factor_y = mpp_closest_lvl_y / mpp[1] target_res_x = int(np.round(closest_lvl_dim[1] * ds_factor_x)) target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y + + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: + """ + Determines if user-provided MPP values are within a specified tolerance of the closest + level's MPP and checks if the closest level has higher resolution than desired MPP. + + Args: + closest_lvl: Whole slide image level closest to user-provided MPP resolution. + mpp_list: List of MPP values for all levels of the whole slide image. + mpp: The MPP resolution at which the whole slide image representation should be extracted. + atol: Absolute tolerance for MPP comparison. + rtol: Relative tolerance for MPP comparison. + + """ + user_mpp_x, user_mpp_y = mpp + mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] + + # Define tolerance intervals for x and y of closest level + lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol + upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol + lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol + upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol + + # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level + is_within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) + is_within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) + is_within_tolerance = is_within_tolerance_x & is_within_tolerance_y + + # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp + closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x + closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y + closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + + return is_within_tolerance, closest_level_is_bigger def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -802,50 +837,27 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 rtol: the acceptable relative tolerance for resolution in micro per pixel. """ - - # cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - - # closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - - # mpp_closest_lvl = mpp_list[closest_lvl] - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.resolutions["level_dimensions"][closest_lvl], num_workers=self.num_workers ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -1087,43 +1099,25 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region( (0, 0), level=closest_lvl, size=wsi.level_dimensions[closest_lvl] ) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - # Else: increase resolution (ie, decrement level) and then downsample - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1342,40 +1336,23 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 """ - user_mpp_x, user_mpp_y = mpp mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - # -> Should not throw ValueError, instead just return the closest value; how to select tolerances? - - mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] - - # Define tolerance intervals for x and y of closest level - lower_bound_x = mpp_closest_lvl_x * (1 - rtol) - atol - upper_bound_x = mpp_closest_lvl_x * (1 + rtol) + atol - lower_bound_y = mpp_closest_lvl_y * (1 - rtol) - atol - upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol - # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - within_tolerance = within_tolerance_x & within_tolerance_y + within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: - # Take closest_level and continue with returning img at level + # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=self.get_size(wsi, closest_lvl)) - else: - # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp - closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x - closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y - - if closest_level_is_bigger: - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + elif closest_level_is_bigger: + # Otherwise, select the level closest to the desired mpp with a higher resolution and downsample it. + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - else: - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr From 2730abe324fe6ca1dd7e9672f51e7a49e8352f0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Aug 2024 20:46:19 +0000 Subject: [PATCH 44/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 7b2e2eb0db..57df016140 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -453,7 +453,7 @@ def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: t target_res_y = int(np.round(closest_lvl_dim[0] * ds_factor_y)) return target_res_x, target_res_y - + def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> bool: """ Determines if user-provided MPP values are within a specified tolerance of the closest From 9d817e792c056d65c084f37304098fa4c9b4ebd5 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Tue, 4 Mar 2025 11:12:27 +0100 Subject: [PATCH 45/55] Updated WSI reader --- monai/data/wsi_reader.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 57df016140..8b06a0dd06 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -842,7 +842,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.resolutions["level_count"])] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -962,7 +962,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = wsi.resolutions["level_dimensions"][closest_lvl] - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) @@ -1067,7 +1067,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and "tiff.YResolution" in wsi.properties and wsi.properties["tiff.YResolution"] and wsi.properties["tiff.XResolution"] - ): + ): unit = wsi.properties.get("tiff.ResolutionUnit") if unit is None: warnings.warn("The resolution unit is missing, `micrometer` will be used as default.") @@ -1102,7 +1102,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -1207,7 +1207,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = wsi.level_dimensions[closest_lvl] - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) @@ -1339,7 +1339,7 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles closest_lvl = self._find_closest_level("mpp", mpp, mpp_list, 0, 5) - within_tolerance, closest_level_is_bigger = super()._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) + within_tolerance, closest_level_is_bigger = self._compute_mpp_tolerances(closest_lvl, mpp_list, mpp, atol, rtol) if within_tolerance: # If the image at the desired mpp resolution is within tolerances, return the image at closest_level. @@ -1449,7 +1449,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_dim = self.get_size(wsi, closest_lvl) - target_res_x, target_res_y = super()._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) + target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) From 2afa6fb6913e732135803feec091efcc450d7ec9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Mar 2025 10:19:38 +0000 Subject: [PATCH 46/55] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 8b06a0dd06..70f32110ff 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1067,7 +1067,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: and "tiff.YResolution" in wsi.properties and wsi.properties["tiff.YResolution"] and wsi.properties["tiff.XResolution"] - ): + ): unit = wsi.properties.get("tiff.ResolutionUnit") if unit is None: warnings.warn("The resolution unit is missing, `micrometer` will be used as default.") From 8270658ccfa4b0d0d7b943d78ecec7696ac6c1de Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Thu, 17 Jul 2025 02:01:34 +0200 Subject: [PATCH 47/55] Added get_wsi_at_mpp tests; fixed a few bugs --- monai/data/wsi_reader.py | 9 ++- tests/utils/enums/test_wsireader.py | 103 ++++++++++++++++++++++++---- 2 files changed, 98 insertions(+), 14 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 70f32110ff..2962108414 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -965,7 +965,12 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) wsi_arr = cp.array(wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim, num_workers=self.num_workers)) - closest_lvl_wsi = cucim_resize(wsi_arr, (target_res_x, target_res_y), order=1) + closest_lvl_wsi = cucim_resize( + wsi_arr, + (target_res_x, target_res_y), + order=1, + preserve_range=True, + anti_aliasing=False).astype(cp.uint8) return closest_lvl_wsi @@ -1210,7 +1215,7 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): target_res_x, target_res_y = self._compute_mpp_target_res(closest_lvl, closest_lvl_dim, mpp_list, user_mpp) closest_lvl_wsi = wsi.read_region((0, 0), level=closest_lvl, size=closest_lvl_dim) - closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) + closest_lvl_wsi = closest_lvl_wsi.resize((target_res_y, target_res_x), pil_image.BILINEAR) # row, col order return closest_lvl_wsi diff --git a/tests/utils/enums/test_wsireader.py b/tests/utils/enums/test_wsireader.py index 3b84af7345..df06bf3ec7 100644 --- a/tests/utils/enums/test_wsireader.py +++ b/tests/utils/enums/test_wsireader.py @@ -37,9 +37,12 @@ has_tiff = has_tiff and has_codec TESTS_PATH = Path(__file__).parents[2] -WSI_GENERIC_TIFF_KEY = "wsi_generic_tiff" +WSI_GENERIC_TIFF_KEY = "wsi_generic_tiff" # TIFF image with incorrect mpp values WSI_GENERIC_TIFF_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_GENERIC_TIFF_KEY}.tiff") +WSI_GENERIC_TIFF_CORRECT_MPP_KEY = "wsi_generic_tiff_corrected" +WSI_GENERIC_TIFF_CORRECT_MPP_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_GENERIC_TIFF_CORRECT_MPP_KEY}.tiff") + WSI_APERIO_SVS_KEY = "wsi_aperio_svs" WSI_APERIO_SVS_PATH = os.path.join(TESTS_PATH, "testing_data", f"temp_{WSI_APERIO_SVS_KEY}.svs") @@ -256,6 +259,54 @@ "cpu", ] +TEST_CASE_SVS_MPP_1 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (4.0, 4.0), "atol": 0.0, "rtol": 0.1}, + {"openslide": (4106, 5739, 4), "cucim": (4106, 5739, 3)}, +] + +TEST_CASE_SVS_MPP_2 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (8.0, 8.0)}, + {"openslide": (2057, 2875, 4), "cucim": (2057, 2875, 3)}, +] + +TEST_CASE_SVS_MPP_3 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (3.0, 3.0)}, + {"openslide": (5475, 7652, 4), "cucim": (5475, 7652, 3)}, +] + +TEST_CASE_SVS_MPP_4 = [ + WSI_APERIO_SVS_PATH, + {"mpp": (1.5, 1.5)}, + {"openslide": (10949, 15303, 4), "cucim": (10949, 15303, 3)}, +] + +TEST_CASE_TIFF_MPP_1 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (4.0, 4.0), "atol": 0.0, "rtol": 0.1}, + {"openslide": (4114, 5750, 4), "cucim": (4114, 5750, 3), "tifffile": (4106, 5739, 3)}, +] + +TEST_CASE_TIFF_MPP_2 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (8.0, 8.0)}, + {"openslide": (2057, 2875, 4), "cucim": (2057, 2875, 3), "tifffile": (2053, 2869, 3)}, +] + +TEST_CASE_TIFF_MPP_3 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (3.0, 3.0)}, + {"openslide": (5475, 7652, 4), "cucim": (5475, 7652, 3), "tifffile": (5475, 7651, 3)}, +] + +TEST_CASE_TIFF_MPP_4 = [ + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + {"mpp": (1.5, 1.5)}, + {"openslide": (10949, 15303, 4), "cucim": (10949, 15303, 3), "tifffile": (10949, 15303, 3)}, +] + TEST_CASE_DEVICE_2 = [ WSI_GENERIC_TIFF_PATH, {"level": 8, "dtype": torch.float32, "device": "cuda"}, @@ -407,17 +458,45 @@ class WSIReaderTests: class Tests(unittest.TestCase): backend = None - @parameterized.expand([TEST_CASE_WHOLE_0]) - def test_read_whole_image(self, file_path, level, expected_shape): - reader = WSIReader(self.backend, level=level) - with reader.read(file_path) as img_obj: - img, meta = reader.get_data(img_obj) - self.assertTupleEqual(img.shape, expected_shape) - self.assertEqual(meta["backend"], self.backend) - self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) - self.assertEqual(meta[WSIPatchKeys.LEVEL], level) - assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) - assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + # @parameterized.expand([TEST_CASE_WHOLE_0]) + # def test_read_whole_image(self, file_path, level, expected_shape): + # reader = WSIReader(self.backend, level=level) + # with reader.read(file_path) as img_obj: + # img, meta = reader.get_data(img_obj) + # self.assertTupleEqual(img.shape, expected_shape) + # self.assertEqual(meta["backend"], self.backend) + # self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower()) + # self.assertEqual(meta[WSIPatchKeys.LEVEL], level) + # assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False) + # assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False) + + @parameterized.expand( + [ + TEST_CASE_SVS_MPP_1, + TEST_CASE_SVS_MPP_2, + TEST_CASE_SVS_MPP_3, + TEST_CASE_SVS_MPP_4, + TEST_CASE_TIFF_MPP_1, + TEST_CASE_TIFF_MPP_2, + TEST_CASE_TIFF_MPP_3, + TEST_CASE_TIFF_MPP_4 + ] + ) + def test_get_wsi_at_mpp(self, file_path, func_kwargs, expected_shape): + # Tifffile backend cannot read MPP from the SVS file, so skip. + if self.backend == "tifffile" and file_path == WSI_APERIO_SVS_PATH: + self.skipTest("TiffFileWSIReader cannot extract MPP from SVS files.") + + # Look up the expected shape for the current backend + if self.backend not in expected_shape: + self.skipTest(f"No expected shape defined for backend '{self.backend}' in this test case.") + expected_shape = expected_shape[self.backend] + + reader = WSIReader(self.backend) + with reader.read(file_path) as wsi: + wsi_arr = reader.get_wsi_at_mpp(wsi, **func_kwargs) + + self.assertTupleEqual(wsi_arr.shape, expected_shape) @parameterized.expand( [ From 23e4a74f29e8254d858ab2347de107db5dbdc497 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Thu, 17 Jul 2025 08:51:27 +0200 Subject: [PATCH 48/55] Added get_wsi_at_mpp tests; fixed a few bugs --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 2962108414..1ff4cc23f3 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1459,4 +1459,4 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - return closest_lvl_wsi + return closest_lvl_wsi \ No newline at end of file From 787d30f3c49fa57675c3a9a23d1957a10ed48328 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 18 Jul 2025 06:51:32 +0200 Subject: [PATCH 49/55] Added get_wsi_at_mpp tests; fixed a few bugs --- tests/test_sure_loss.py | 72 ----------------------------------------- 1 file changed, 72 deletions(-) delete mode 100644 tests/test_sure_loss.py diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py deleted file mode 100644 index 903f9bd2ca..0000000000 --- a/tests/test_sure_loss.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch - -from monai.losses import SURELoss - - -class TestSURELoss(unittest.TestCase): - - def test_real_value(self): - """Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.""" - sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1) - - def operator(x): - return x - - y_pseudo_gt = torch.randn(2, 1, 128, 128) - x = torch.randn(2, 1, 128, 128) - loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False) - self.assertAlmostEqual(loss.item(), 0.0) - - def test_complex_value(self): - """Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.""" - - def operator(x): - return x - - sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1) - y_pseudo_gt = torch.randn(2, 2, 128, 128) - x = torch.randn(2, 2, 128, 128) - loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True) - self.assertAlmostEqual(loss.item(), 0.0) - - def test_complex_general_input(self): - """Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.""" - - def operator(x): - return x - - perturb_noise_real = torch.randn(2, 1, 128, 128) - perturb_noise_complex = torch.zeros(2, 2, 128, 128) - perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze() - y_pseudo_gt_real = torch.randn(2, 1, 128, 128) - y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128) - y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze() - x_real = torch.randn(2, 1, 128, 128) - x_complex = torch.zeros(2, 2, 128, 128) - x_complex[:, 0, :, :] = x_real.squeeze() - - sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1) - sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1) - - loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) - loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) - self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) - - -if __name__ == "__main__": - unittest.main() From 0d9f1dd4f20518e9db71d6fd55c7b0d38d108e17 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 18 Jul 2025 06:54:06 +0200 Subject: [PATCH 50/55] Added get_wsi_at_mpp tests; fixed a few bugs --- tests/test_conjugate_gradient.py | 56 -------------------------------- 1 file changed, 56 deletions(-) delete mode 100644 tests/test_conjugate_gradient.py diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py deleted file mode 100644 index 64efe3b168..0000000000 --- a/tests/test_conjugate_gradient.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch - -from monai.networks.layers import ConjugateGradient - - -class TestConjugateGradient(unittest.TestCase): - - def test_real_valued_inverse(self): - """Test ConjugateGradient with real-valued input: when the input is real - value, the output should be the inverse of the matrix.""" - a_dim = 3 - a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float) - - def a_op(x): - return a_mat @ x - - cg_solver = ConjugateGradient(a_op, num_iter=100) - # define the measurement - y = torch.tensor([1, 2, 3], dtype=torch.float) - # solve for x - x = cg_solver(torch.zeros(a_dim), y) - x_ref = torch.linalg.solve(a_mat, y) - # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution' - self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) - - def test_complex_valued_inverse(self): - a_dim = 3 - a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64) - - def a_op(x): - return a_mat @ x - - cg_solver = ConjugateGradient(a_op, num_iter=100) - y = torch.tensor([1, 2, 3], dtype=torch.complex64) - x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y) - x_ref = torch.linalg.solve(a_mat, y) - self.assertTrue(torch.allclose(x, x_ref, atol=1e-6)) - - -if __name__ == "__main__": - unittest.main() From 45182fa63c7dce2d346fc332d1f607518d990dc5 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 18 Jul 2025 06:58:50 +0200 Subject: [PATCH 51/55] Added get_wsi_at_mpp tests; fixed a few bugs --- monai/data/wsi_reader.py | 2 +- monai/transforms/regularization/array.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 1ff4cc23f3..2962108414 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1459,4 +1459,4 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl_wsi = pil_image.fromarray(wsi.pages[closest_lvl].asarray()) closest_lvl_wsi = closest_lvl_wsi.resize((target_res_x, target_res_y), pil_image.BILINEAR) - return closest_lvl_wsi \ No newline at end of file + return closest_lvl_wsi diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py index 66a5116c1a..768e3ed2bb 100644 --- a/monai/transforms/regularization/array.py +++ b/monai/transforms/regularization/array.py @@ -25,7 +25,6 @@ class Mixer(RandomizableTransform): - def __init__(self, batch_size: int, alpha: float = 1.0) -> None: """ Mixer is a base class providing the basic logic for the mixup-class of From 832c14e5afe2e6062eaa7263baee151d1b758c41 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 18 Jul 2025 20:59:26 +0200 Subject: [PATCH 52/55] Added get_wsi_at_mpp tests; fixed a few bugs --- monai/data/wsi_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 2962108414..f8db1a920f 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -1316,7 +1316,7 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]: # Here, x and y resolutions are rational numbers so each of them is represented by a tuple. yres = wsi.pages[level].tags["YResolution"].value xres = wsi.pages[level].tags["XResolution"].value - if xres[0] & yres[0]: + if xres[0] and yres[0]: return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0]) else: raise ValueError("The `XResolution` and/or `YResolution` property of the image is zero, " From 4249b4846d979d6d24262e875b5b187bda08fb1e Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 25 Jul 2025 19:15:39 +0200 Subject: [PATCH 53/55] DCO Remediation Commit for Nikolas Schmitz I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: d0a4881b193ad26c5794b9b1d57b8a924a87b577 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 7545148a4bb937855e6e84841899c33aa33699f9 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 5c7822f8f706e54b4bc803d61f08a631e1054af7 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 234f23f6495ad5864e6e74d6251eba6ebcd2a317 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 9d817e792c056d65c084f37304098fa4c9b4ebd5 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 8270658ccfa4b0d0d7b943d78ecec7696ac6c1de I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 23e4a74f29e8254d858ab2347de107db5dbdc497 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 787d30f3c49fa57675c3a9a23d1957a10ed48328 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 0d9f1dd4f20518e9db71d6fd55c7b0d38d108e17 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 45182fa63c7dce2d346fc332d1f607518d990dc5 I, Nikolas Schmitz , hereby add my Signed-off-by to this commit: 832c14e5afe2e6062eaa7263baee151d1b758c41 Signed-off-by: Nikolas Schmitz --- tests/testing_data/data_config.json | 5 +++++ tests/utils/enums/test_wsireader.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 79033dd0d6..69ec8f0813 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -5,6 +5,11 @@ "hash_type": "sha256", "hash_val": "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" }, + "wsi_generic_tiff_corrected": { + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1_correct_mpp.tiff", + "hash_type": "sha256", + "hash_val": "65306e3f8f7f5282d19d942dadc525cd06a80d5fd8268053939751365226c65f" + }, "wsi_aperio_svs": { "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Aperio-CMU-1.svs", "hash_type": "sha256", diff --git a/tests/utils/enums/test_wsireader.py b/tests/utils/enums/test_wsireader.py index df06bf3ec7..0115e1caba 100644 --- a/tests/utils/enums/test_wsireader.py +++ b/tests/utils/enums/test_wsireader.py @@ -446,6 +446,12 @@ def setUpModule(): hash_type=testing_data_config("images", WSI_GENERIC_TIFF_KEY, "hash_type"), hash_val=testing_data_config("images", WSI_GENERIC_TIFF_KEY, "hash_val"), ) + download_url_or_skip_test( + testing_data_config("images", WSI_GENERIC_TIFF_CORRECT_MPP_KEY, "url"), + WSI_GENERIC_TIFF_CORRECT_MPP_PATH, + hash_type=testing_data_config("images", WSI_GENERIC_TIFF_CORRECT_MPP_KEY, "hash_type"), + hash_val=testing_data_config("images", WSI_GENERIC_TIFF_CORRECT_MPP_KEY, "hash_val"), + ) download_url_or_skip_test( testing_data_config("images", WSI_APERIO_SVS_KEY, "url"), WSI_APERIO_SVS_PATH, From 01e60e08c0a236ca173ec2dfb16043e238144412 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Wed, 12 Nov 2025 17:54:29 +0100 Subject: [PATCH 54/55] Minor changes --- monai/data/wsi_reader.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index f8db1a920f..4c545551d9 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -477,14 +477,14 @@ def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> boo upper_bound_y = mpp_closest_lvl_y * (1 + rtol) + atol # Check if user-provided mpp_x and mpp_y fall within the tolerance intervals for closest level - is_within_tolerance_x = (user_mpp_x >= lower_bound_x) & (user_mpp_x <= upper_bound_x) - is_within_tolerance_y = (user_mpp_y >= lower_bound_y) & (user_mpp_y <= upper_bound_y) - is_within_tolerance = is_within_tolerance_x & is_within_tolerance_y + is_within_tolerance_x = (user_mpp_x >= lower_bound_x) and (user_mpp_x <= upper_bound_x) + is_within_tolerance_y = (user_mpp_y >= lower_bound_y) and (user_mpp_y <= upper_bound_y) + is_within_tolerance = is_within_tolerance_x and is_within_tolerance_y # If mpp_closest_level < mpp -> closest_level has higher res than img at mpp => downscale from closest_level to mpp closest_level_is_bigger_x = mpp_closest_lvl_x < user_mpp_x closest_level_is_bigger_y = mpp_closest_lvl_y < user_mpp_y - closest_level_is_bigger = closest_level_is_bigger_x & closest_level_is_bigger_y + closest_level_is_bigger = closest_level_is_bigger_x and closest_level_is_bigger_y return is_within_tolerance, closest_level_is_bigger @@ -856,6 +856,8 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. + if closest_lvl == 0: + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) closest_lvl = closest_lvl - 1 closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) From 4460f45c0ee4ad879e3b0062f0d82c6479feba79 Mon Sep 17 00:00:00 2001 From: Nikolas Schmitz Date: Fri, 14 Nov 2025 11:54:39 +0100 Subject: [PATCH 55/55] Fixed bug in get_wsi_at_mpp; added Return sections to docstrings Signed-off-by: Nikolas Schmitz --- monai/data/wsi_reader.py | 47 +++++++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index 4c545551d9..20528ae32d 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -442,6 +442,9 @@ def _compute_mpp_target_res(self, closest_lvl, closest_lvl_dim, mpp_list, mpp: t mpp_list: List of MPP values for all levels of the whole slide image. mpp: The MPP resolution at which the whole slide image representation should be extracted. + Returns: + Tuple of (target_res_x, target_res_y) representing the target pixel dimensions. + """ mpp_closest_lvl = mpp_list[closest_lvl] mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_closest_lvl @@ -466,6 +469,10 @@ def _compute_mpp_tolerances(self, closest_lvl, mpp_list, mpp, atol, rtol) -> boo atol: Absolute tolerance for MPP comparison. rtol: Relative tolerance for MPP comparison. + Returns: + Tuple of (is_within_tolerance, closest_level_is_bigger) where first element indicates + if MPP is within tolerance and second indicates if closest level has higher resolution. + """ user_mpp_x, user_mpp_y = mpp mpp_closest_lvl_x, mpp_closest_lvl_y = mpp_list[closest_lvl] @@ -676,6 +683,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 atol: the acceptable absolute tolerance for resolution in micro per pixel. rtol: the acceptable relative tolerance for resolution in micro per pixel. + Returns: + Numpy array containing the whole slide image at the requested MPP resolution. + """ return self.reader.get_wsi_at_mpp(wsi, mpp, atol, rtol) @@ -836,6 +846,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 atol: the acceptable absolute tolerance for resolution in micro per pixel. rtol: the acceptable relative tolerance for resolution in micro per pixel. + Returns: + Cupy array containing the whole slide image at the requested MPP resolution. + """ cp, _ = optional_import("cupy") @@ -858,8 +871,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. if closest_lvl == 0: closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = cp.asnumpy(closest_lvl_wsi) return wsi_arr @@ -958,6 +972,9 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. + Returns: + Resized cupy image array at the target MPP resolution. + """ cucim_resize, _ = optional_import("cucim.skimage.transform", name="resize") cp, _ = optional_import("cupy") @@ -1104,6 +1121,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 atol: the acceptable absolute tolerance for resolution in micro per pixel. rtol: the acceptable relative tolerance for resolution in micro per pixel. + Returns: + Numpy array containing the whole slide image at the requested MPP resolution. + """ mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(wsi.level_count)] @@ -1123,8 +1143,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + if closest_lvl == 0: + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1209,6 +1232,9 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. + Returns: + PIL Image object resized to the target MPP resolution. + """ pil_image, _ = optional_import("PIL", name="Image") @@ -1341,6 +1367,9 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 atol: the acceptable absolute tolerance for resolution in micro per pixel. rtol: the acceptable relative tolerance for resolution in micro per pixel. + Returns: + Numpy array containing the whole slide image at the requested MPP resolution. + """ mpp_list = [self.get_mpp(wsi, lvl) for lvl in range(len(wsi.pages))] # Fails for some Tifffiles @@ -1358,8 +1387,11 @@ def get_wsi_at_mpp(self, wsi, mpp: tuple, atol: float = 0.00, rtol: float = 0.05 else: # If both checks fail, increase resolution (i.e., decrement level) and then downsample it. - closest_lvl = closest_lvl - 1 - closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + if closest_lvl == 0: + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) + else: + closest_lvl = closest_lvl - 1 + closest_lvl_wsi = self._resize_to_mpp_res(wsi, closest_lvl, mpp_list, mpp) wsi_arr = np.array(closest_lvl_wsi) return wsi_arr @@ -1451,6 +1483,9 @@ def _resize_to_mpp_res(self, wsi, closest_lvl, mpp_list, user_mpp: tuple): closest_lvl: the wsi level that is closest to the user-provided mpp resolution. mpp_list: list of mpp values for all levels of a whole slide image. + Returns: + PIL Image object resized to the target MPP resolution. + """ pil_image, _ = optional_import("PIL", name="Image")