Skip to content

Commit

Permalink
Better patches generation on deep learning segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tfmoraes committed Sep 23, 2024
1 parent 3d73c09 commit 9a26872
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions invesalius/segmentation/deep_learning/segment.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,52 @@
import itertools
import multiprocessing
import os
import pathlib
import sys
import tempfile
import traceback
from typing import Generator, Tuple

import numpy as np
from skimage.transform import resize
from vtkmodules.vtkIOXML import vtkXMLImageDataWriter

import invesalius.data.slice_ as slc
from invesalius import inv_paths
from invesalius.data import imagedata_utils
from invesalius.data.converters import to_vtk
from invesalius.net.utils import download_url_to_file
from invesalius.utils import new_name_by_pattern

from . import utils

SIZE = 48

patch_type = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]

def gen_patches(image, patch_size, overlap):

def gen_patches(
image: np.ndarray, patch_size: int, overlap: int
) -> Generator[Tuple[float, np.ndarray, patch_type], None, None]:
overlap = int(patch_size * overlap / 100)
sz, sy, sx = image.shape
i_cuts = list(
itertools.product(
range(0, sz, patch_size - overlap),
range(0, sy, patch_size - overlap),
range(0, sx, patch_size - overlap),
)
)
slices_x = [i for i in range(0, sx, patch_size - overlap) if i + patch_size <= sx]
if not slices_x:
slices_x.append(0)
elif slices_x[-1] + patch_size < sx:
slices_x.append(sx - patch_size)
slices_y = [i for i in range(0, sy, patch_size - overlap) if i + patch_size <= sy]
if not slices_y:
slices_y.append(0)
elif slices_y[-1] + patch_size < sy:
slices_y.append(sy - patch_size)
slices_z = [i for i in range(0, sz, patch_size - overlap) if i + patch_size <= sz]
if not slices_z:
slices_z.append(0)
elif slices_z[-1] + patch_size < sz:
slices_z.append(sz - patch_size)
i_cuts = list(itertools.product(slices_z, slices_y, slices_x))

sub_image = np.empty(shape=(patch_size, patch_size, patch_size), dtype="float32")
for idx, (iz, iy, ix) in enumerate(i_cuts):
sub_image[:] = 0
Expand Down

0 comments on commit 9a26872

Please sign in to comment.