From 3ae403aa936f2ba0dd7a4bc4e83fac160fb882e3 Mon Sep 17 00:00:00 2001 From: zethson Date: Sat, 16 Nov 2024 13:14:20 +0100 Subject: [PATCH] filter_segmentation Signed-off-by: zethson --- .../mask_filtering/filter_segmentation.py | 262 ++++++++---------- 1 file changed, 115 insertions(+), 147 deletions(-) diff --git a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py index 10f3e6e7..2025d9a8 100644 --- a/src/scportrait/pipeline/mask_filtering/filter_segmentation.py +++ b/src/scportrait/pipeline/mask_filtering/filter_segmentation.py @@ -2,13 +2,15 @@ import gc import os import shutil -import sys import traceback from collections import defaultdict from multiprocessing import Pool +from pathlib import Path +from typing import Any import h5py import numpy as np +import numpy.typing as npt from alphabase.io import tempmmap from tqdm.auto import tqdm @@ -19,24 +21,19 @@ class SegmentationFilter(ProcessingStep): """SegmentationFilter helper class used for creating workflows to filter generated segmentation masks before extraction.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.identifier = None - self.window = None - self.input_path = None + self.identifier: int | None = None + self.window: tuple[slice, slice] | None = None + self.input_path: str | Path | None = None - def read_input_masks(self, input_path): - """ - Read input masks from a given HDF5 file. + def read_input_masks(self, input_path: str | Path) -> npt.NDArray[np.uint16]: + """Read input masks from a given HDF5 file. - Parameters - ---------- - input_path : str - Path to the HDF5 file containing the input masks. + Args: + input_path: Path to the HDF5 file containing the input masks. - Returns - ------- - numpy.ndarray + Returns: Array containing the input masks. """ with h5py.File(input_path, "r") as hf: @@ -49,53 +46,52 @@ def read_input_masks(self, input_path): input_masks = hdf_input[:2, :, :] return input_masks - def save_classes(self, classes): - """ - Save the filtered classes to a CSV file. + def save_classes(self, classes: dict[str, list[Any]]) -> None: + """Save the filtered classes to a CSV file. - Parameters - ---------- - classes : dict - Dictionary of classes to save. + Args: + classes: Dictionary of classes to save. """ - filtered_path = os.path.join(self.directory, self.DEFAULT_FILTERED_CLASSES_FILE) + filtered_path = Path(self.directory) / self.DEFAULT_REMOVED_CLASSES_FILE to_write = "\n".join([f"{str(x)}:{str(y)}" for x, y in classes.items()]) with open(filtered_path, "w") as myfile: myfile.write(to_write) self.log(f"Saved nucleus_id:cytosol_id matchings of all cells that passed filtering to {filtered_path}.") - def initialize_as_tile(self, identifier, window, input_path, zarr_status=True): - """ - Initialize Filtering Step with further parameters needed for filtering segmentation results. + def initialize_as_tile( + self, identifier: int, window: tuple[slice, slice], input_path: str | Path, zarr_status: bool = True + ) -> None: + """Initialize Filtering Step with further parameters needed for filtering segmentation results. Important: - This function is intended for internal use by the :class:`TiledFilterSegmentation` helper class. In most cases it is not relevant to the creation of custom filtering workflows. - - Parameters - ---------- - identifier : int - Unique index of the tile. - window : list of tuple - Defines the window which is assigned to the tile. The window will be applied to the input. The first element refers to the first dimension of the image and so on. - input_path : str - Location of the input HDF5 file. During tiled segmentation the :class:`TiledSegmentation` derived helper class will save the input image in form of a HDF5 file. - zarr_status : bool, optional - Status of zarr saving, by default True. + This function is intended for internal use by the :class:`TiledFilterSegmentation` helper class. + In most cases it is not relevant to the creation of custom filtering workflows. + + Args: + identifier: Unique index of the tile. + window: Defines the window which is assigned to the tile. The window will be applied + to the input. The first element refers to the first dimension of the image and so on. + input_path: Location of the input HDF5 file. During tiled segmentation the + :class:`TiledSegmentation` derived helper class will save the input image + in form of a HDF5 file. + zarr_status: Status of zarr saving, by default True. """ self.identifier = identifier self.window = window self.input_path = input_path self.save_zarr = zarr_status - def call_as_tile(self): - """ - Wrapper function for calling segmentation filtering on an image tile. + def call_as_tile(self) -> None: + """Wrapper function for calling segmentation filtering on an image tile. Important: - This function is intended for internal use by the :class:`TiledSegmentation` helper class. In most cases it is not relevant to the creation of custom segmentation workflows. + This function is intended for internal use by the :class:`TiledSegmentation` helper class. + In most cases it is not relevant to the creation of custom segmentation workflows. """ with h5py.File(self.input_path, "r") as hf: hdf_input = hf.get("labels") + if hdf_input is None: + raise ValueError("No 'labels' dataset found in HDF5 file") c, _, _ = hdf_input.shape x1 = self.window[0].start @@ -103,6 +99,9 @@ def call_as_tile(self): y1 = self.window[1].start y2 = self.window[1].stop + if any(v is None for v in [x1, x2, y1, y2]): + raise ValueError("Window slice boundaries cannot be None") + x = x2 - x1 y = y2 - y1 @@ -132,127 +131,106 @@ def call_as_tile(self): f.write(f"{self.window}\n") self.log(f"Filtering of tile with the slicing {self.window} finished.") - def get_output(self): - """ - Get the output file path. - - Returns - ------- - str - Path to the output file. - """ - return os.path.join(self.directory, self.DEFAULT_SEGMENTATION_FILE) + def get_output_path(self) -> Path: + return Path(self.directory) / self.DEFAULT_SEGMENTATION_FILE class TiledSegmentationFilter(SegmentationFilter): """TiledSegmentationFilter helper class used for creating workflows to filter generated segmentation masks using a tiled approach.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) if not hasattr(self, "method"): raise AttributeError("No SegmentationFilter method defined, please set attribute ``method``") + self.tile_directory: Path | None = None - def initialize_tile_list(self, tileing_plan, input_path): - """ - Initialize the list of tiles for segmentation filtering. + def initialize_tile_list( + self, tileing_plan: list[tuple[slice, slice]], input_path: str | Path + ) -> list[SegmentationFilter]: + """Initialize the list of tiles for segmentation filtering. - Parameters - ---------- - tileing_plan : list of tuple - List of windows defining the tiling plan. - input_path : str - Path to the input HDF5 file. + Args: + tileing_plan: List of windows defining the tiling plan. + input_path: Path to the input HDF5 file. - Returns - ------- - list + Returns: List of initialized tiles. """ _tile_list = [] self.input_path = input_path for i, window in enumerate(tileing_plan): - local_tile_directory = os.path.join(self.tile_directory, str(i)) - current_tile = self.method( + local_tile_directory = Path(self.tile_directory) / str(i) + current_tile = self.method( # type: ignore self.config, local_tile_directory, project_location=self.project_location, debug=self.debug, overwrite=self.overwrite, - intermediate_output=self.intermediate_output, + intermediate_output=self.intermediate_output, # type: ignore ) - current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False) + current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False) # type: ignore _tile_list.append(current_tile) return _tile_list - def initialize_tile_list_incomplete(self, tileing_plan, incomplete_indexes, input_path): - """ - Initialize the list of incomplete tiles for segmentation filtering. + def initialize_tile_list_incomplete( + self, tileing_plan: list[tuple[slice, slice]], incomplete_indexes: list[int], input_path: str | Path + ) -> list[SegmentationFilter]: + """Initialize the list of incomplete tiles for segmentation filtering. - Parameters - ---------- - tileing_plan : list of tuple - List of windows defining the tiling plan. - incomplete_indexes : list of int - List of indexes for incomplete tiles. - input_path : str - Path to the input HDF5 file. - - Returns - ------- - list + Args: + tileing_plan: List of windows defining the tiling plan. + incomplete_indexes: List of indexes for incomplete tiles. + input_path: Path to the input HDF5 file. + + Returns: List of initialized incomplete tiles. """ _tile_list = [] self.input_path = input_path for i, window in zip(incomplete_indexes, tileing_plan, strict=False): - local_tile_directory = os.path.join(self.tile_directory, str(i)) - current_tile = self.method( + local_tile_directory = Path(self.tile_directory) / str(i) + current_tile = self.method( # type: ignore self.config, local_tile_directory, project_location=self.project_location, debug=self.debug, overwrite=self.overwrite, - intermediate_output=self.intermediate_output, + intermediate_output=self.intermediate_output, # type: ignore ) - current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False) + current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False) # type: ignore _tile_list.append(current_tile) return _tile_list - def calculate_tileing_plan(self, mask_size): - """ - Calculate the tiling plan based on the mask size. + def calculate_tileing_plan(self, mask_size: tuple[int, int]) -> list[tuple[slice, slice]]: + """Calculate the tiling plan based on the mask size. - Parameters - ---------- - mask_size : tuple - Size of the mask. + Args: + mask_size: Size of the mask. - Returns - ------- - list of tuple + Returns: List of windows defining the tiling plan. """ tileing_plan_path = f"{self.directory}/tileing_plan.csv" - if os.path.isfile(tileing_plan_path): + if Path(tileing_plan_path).is_file(): self.log(f"tileing plan already found in directory {tileing_plan_path}.") if self.overwrite: self.log("Overwriting existing tileing plan.") - os.remove(tileing_plan_path) + Path(tileing_plan_path).unlink() else: self.log("Reading existing tileing plan from file.") - with open(tileing_plan_path) as f: - _tileing_plan = [eval(line) for line in f.readlines()] - return _tileing_plan + with Path(tileing_plan_path).open() as f: + tileing_plan = [eval(line) for line in f.readlines()] + return tileing_plan - _tileing_plan = [] + _tileing_plan: list[tuple[slice, slice]] = [] side_size = np.floor(np.sqrt(int(self.config["tile_size"]))) - tiles_side = np.round(mask_size / side_size).astype(int) - tile_size = mask_size // tiles_side + tiles_side = np.round(np.array(mask_size) / side_size).astype(int) + tile_size = np.array(mask_size) // tiles_side self.log(f"input image {mask_size[0]} px by {mask_size[1]} px") self.log(f"target_tile_size: {self.config['tile_size']}") @@ -285,30 +263,24 @@ def calculate_tileing_plan(self, mask_size): return _tileing_plan - def execute_tile_list(self, tile_list, n_cpu=None): - """ - Execute the filtering process for a list of tiles. - - Parameters - ---------- - tile_list : list - List of tiles to process. - n_cpu : int, optional - Number of CPU cores to use, by default None. - - Returns - ------- - list + def execute_tile_list(self, tile_list: list[SegmentationFilter], n_cpu: int | None = None) -> list[Path]: + """Execute the filtering process for a list of tiles. + + Args: + tile_list: List of tiles to process. + n_cpu: Number of CPU cores to use, by default None. + + Returns: List of output file paths for the processed tiles. """ - def f(x): + def f(x: SegmentationFilter) -> Path: try: x.call_as_tile() except (OSError, ValueError, RuntimeError) as e: self.log(f"An error occurred: {e}") self.log(traceback.format_exc()) - return x.get_output() + return x.get_output_path() if n_cpu == 1: self.log(f"Running sequentially on {n_cpu} CPU") @@ -319,54 +291,48 @@ def f(x): with Pool(n_processes) as pool: return list(pool.imap(f, tile_list)) - def execute_tile(self, tile): - """ - Execute the filtering process for a single tile. + def execute_tile(self, tile: SegmentationFilter) -> Path: + """Execute the filtering process for a single tile. - Parameters - ---------- - tile : object - Tile to process. + Args: + tile: Tile to process. - Returns - ------- - str + Returns: Output file path for the processed tile. """ tile.call_as_tile() - return tile.get_output() + return tile.get_output_path() - def initialize_tile_directory(self): - """ - Initialize the directory for storing tile outputs. - """ - self.tile_directory = os.path.join(self.directory, self.DEFAULT_TILES_FOLDER) - if os.path.exists(self.tile_directory): + def initialize_tile_directory(self) -> None: + """Initialize the directory for storing tile outputs.""" + self.tile_directory = Path(self.directory) / self.DEFAULT_TILES_FOLDER + if self.tile_directory.exists(): self.log(f"Directory {self.tile_directory} already exists.") if self.overwrite: self.log("Overwriting existing tiles folder.") shutil.rmtree(self.tile_directory) + self.tile_directory.mkdir() + else: os.makedirs(self.tile_directory) else: - os.makedirs(self.tile_directory) + self.tile_directory.mkdir() - def collect_results(self): - """ - Collect the results from the processed tiles. + def collect_results(self) -> npt.NDArray[np.uint16]: + """Collect the results from the processed tiles. - Returns - ------- - numpy.ndarray + Returns: Array containing the combined results from all tiles. """ self.log("Reading in all tile results") with h5py.File(self.input_path, "r") as hf: hdf_input = hf.get("labels") + if hdf_input is None: + raise ValueError("No 'labels' dataset found in HDF5 file") c, y, x = hdf_input.shape self.log(f"Output image will have shape {c, y, x}") output_image = np.zeros((c, y, x), dtype=np.uint16) - classes = defaultdict(list) + classes: defaultdict[str, list[Any]] = defaultdict(list) with open(f"{self.directory}/window.csv") as f: _window_locations = [eval(line.strip()) for line in f.readlines()] @@ -377,6 +343,8 @@ def collect_results(self): out_dir = os.path.join(self.tile_directory, str(i)) with h5py.File(f"{out_dir}/segmentation.h5", "r") as hf: data = hf.get("labels") + if data is None: + raise ValueError(f"No 'labels' dataset found in tile {i}") for cls, mappings in csv.reader(open(f"{out_dir}/filtered_classes.csv")): classes[cls].append(mappings) output_image[:, loc[0], loc[1]] = data[:, :]