From 9a26872321e5d84c322abedebfe6c40b1d040284 Mon Sep 17 00:00:00 2001 From: Thiago Franco de Moraes Date: Sun, 22 Sep 2024 22:10:59 -0300 Subject: [PATCH] Better patches generation on deep learning segmentation --- .../segmentation/deep_learning/segment.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/invesalius/segmentation/deep_learning/segment.py b/invesalius/segmentation/deep_learning/segment.py index 76a5e6d4d..9e84d8f53 100644 --- a/invesalius/segmentation/deep_learning/segment.py +++ b/invesalius/segmentation/deep_learning/segment.py @@ -1,15 +1,20 @@ import itertools import multiprocessing import os +import pathlib +import sys import tempfile import traceback +from typing import Generator, Tuple import numpy as np from skimage.transform import resize +from vtkmodules.vtkIOXML import vtkXMLImageDataWriter import invesalius.data.slice_ as slc from invesalius import inv_paths from invesalius.data import imagedata_utils +from invesalius.data.converters import to_vtk from invesalius.net.utils import download_url_to_file from invesalius.utils import new_name_by_pattern @@ -17,17 +22,31 @@ SIZE = 48 +patch_type = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]] -def gen_patches(image, patch_size, overlap): + +def gen_patches( + image: np.ndarray, patch_size: int, overlap: int +) -> Generator[Tuple[float, np.ndarray, patch_type], None, None]: overlap = int(patch_size * overlap / 100) sz, sy, sx = image.shape - i_cuts = list( - itertools.product( - range(0, sz, patch_size - overlap), - range(0, sy, patch_size - overlap), - range(0, sx, patch_size - overlap), - ) - ) + slices_x = [i for i in range(0, sx, patch_size - overlap) if i + patch_size <= sx] + if not slices_x: + slices_x.append(0) + elif slices_x[-1] + patch_size < sx: + slices_x.append(sx - patch_size) + slices_y = [i for i in range(0, sy, patch_size - overlap) if i + patch_size <= sy] + if not slices_y: + slices_y.append(0) + elif slices_y[-1] + patch_size < sy: + slices_y.append(sy - patch_size) + slices_z = [i for i in range(0, sz, patch_size - overlap) if i + patch_size <= sz] + if not slices_z: + slices_z.append(0) + elif slices_z[-1] + patch_size < sz: + slices_z.append(sz - patch_size) + i_cuts = list(itertools.product(slices_z, slices_y, slices_x)) + sub_image = np.empty(shape=(patch_size, patch_size, patch_size), dtype="float32") for idx, (iz, iy, ix) in enumerate(i_cuts): sub_image[:] = 0