From 795eb1f2e7d2e7f5bb85b0c8b6593bf4f57a1f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Thu, 19 Dec 2024 23:34:04 +0000 Subject: [PATCH] Address some typing errors (#1253) --- src/torchio/data/image.py | 4 +-- src/torchio/data/io.py | 30 +++++++++++-------- src/torchio/data/sampler/sampler.py | 2 +- src/torchio/data/sampler/weighted.py | 2 +- .../augmentation/intensity/random_motion.py | 3 +- .../augmentation/intensity/random_swap.py | 4 +-- .../augmentation/spatial/random_affine.py | 2 +- .../augmentation/spatial/random_flip.py | 6 ++-- 8 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py index e4f010bd..aba6c15f 100644 --- a/src/torchio/data/image.py +++ b/src/torchio/data/image.py @@ -450,8 +450,8 @@ def get_bounds(self) -> TypeBounds: first_point = apply_affine(self.affine, first_index) last_point = apply_affine(self.affine, last_index) array = np.array((first_point, last_point)) - bounds_x, bounds_y, bounds_z = array.T.tolist() - return bounds_x, bounds_y, bounds_z + bounds_x, bounds_y, bounds_z = array.T.tolist() # type: ignore[misc] + return bounds_x, bounds_y, bounds_z # type: ignore[return-value] @staticmethod def _parse_single_path( diff --git a/src/torchio/data/io.py b/src/torchio/data/io.py index b1e05f0c..4d5ee499 100644 --- a/src/torchio/data/io.py +++ b/src/torchio/data/io.py @@ -5,8 +5,10 @@ import nibabel as nib import numpy as np +import numpy.typing as npt import SimpleITK as sitk import torch +from nibabel.filebasedimages import ImageFileError from nibabel.spatialimages import SpatialImage from ..constants import REPO_URL @@ -36,7 +38,7 @@ def read_image(path: TypePath) -> TypeDataAffine: warnings.warn(message, stacklevel=2) try: result = _read_nibabel(path) - except nib.loadsave.ImageFileError as e: + except ImageFileError as e: message = ( f'File "{path}" not understood.' ' Check supported formats by at' @@ -56,6 +58,7 @@ def _read_nibabel(path: TypePath) -> TypeDataAffine: data = check_uint_to_int(data) tensor = torch.as_tensor(data) affine = img.affine + assert isinstance(affine, np.ndarray) return tensor, affine @@ -155,19 +158,19 @@ def _write_nibabel( else: tensor = tensor[np.newaxis].permute(2, 3, 4, 0, 1) suffix = Path(str(path).replace('.gz', '')).suffix - img: Union[nib.Nifti1Image, nib.Nifti1Pair] + img: Union[nib.nifti1.Nifti1Image, nib.nifti1.Nifti1Pair] if '.nii' in suffix: - img = nib.Nifti1Image(np.asarray(tensor), affine) + img = nib.nifti1.Nifti1Image(np.asarray(tensor), affine) elif '.hdr' in suffix or '.img' in suffix: - img = nib.Nifti1Pair(np.asarray(tensor), affine) + img = nib.nifti1.Nifti1Pair(np.asarray(tensor), affine) else: - raise nib.loadsave.ImageFileError - assert isinstance(img.header, nib.Nifti1Header) + raise ImageFileError + assert isinstance(img.header, nib.nifti1.Nifti1Header) if num_components > 1: img.header.set_intent('vector') img.header['qform_code'] = 1 img.header['sform_code'] = 0 - nib.save(img, str(path)) + nib.loadsave.save(img, str(path)) def _write_sitk( @@ -269,9 +272,9 @@ def _matrix_to_itk_transform( def _read_niftyreg_matrix(trsf_path: TypePath) -> torch.Tensor: """Read a NiftyReg matrix and return it as a NumPy array.""" - matrix = np.loadtxt(trsf_path) - matrix = np.linalg.inv(matrix) - return torch.as_tensor(matrix) + read_matrix = np.loadtxt(trsf_path).astype(np.float64) + inverted = np.linalg.inv(read_matrix) + return torch.from_numpy(inverted) def _write_niftyreg_matrix(matrix: TypeData, txt_path: TypePath) -> None: @@ -361,10 +364,11 @@ def sitk_to_nib( def get_ras_affine_from_sitk( sitk_object: Union[sitk.Image, sitk.ImageFileReader], ) -> np.ndarray: - spacing = np.array(sitk_object.GetSpacing()) - direction_lps = np.array(sitk_object.GetDirection()) - origin_lps = np.array(sitk_object.GetOrigin()) + spacing = np.array(sitk_object.GetSpacing(), dtype=np.float64) + direction_lps = np.array(sitk_object.GetDirection(), dtype=np.float64) + origin_lps = np.array(sitk_object.GetOrigin(), dtype=np.float64) direction_length = len(direction_lps) + rotation_lps: npt.NDArray[np.float64] if direction_length == 9: rotation_lps = direction_lps.reshape(3, 3) elif direction_length == 4: # ignore last dimension if 2D (1, W, H, 1) diff --git a/src/torchio/data/sampler/sampler.py b/src/torchio/data/sampler/sampler.py index 540e168b..9d65290d 100644 --- a/src/torchio/data/sampler/sampler.py +++ b/src/torchio/data/sampler/sampler.py @@ -76,7 +76,7 @@ def _get_crop_transform( crop_ini = index_ini_array.tolist() crop_fin = (shape - index_fin).tolist() start = () - cropping = sum(zip(crop_ini, crop_fin), start) + cropping = sum(zip(crop_ini, crop_fin), start) # type: ignore[arg-type] return Crop(cropping) # type: ignore[arg-type] def __call__( diff --git a/src/torchio/data/sampler/weighted.py b/src/torchio/data/sampler/weighted.py index 81562a38..e62e86e0 100644 --- a/src/torchio/data/sampler/weighted.py +++ b/src/torchio/data/sampler/weighted.py @@ -156,7 +156,7 @@ def clear_probability_borders( # The call tolist() is very important. Using np.uint16 as negative # index will not work because e.g. -np.uint16(2) == 65534 - crop_i, crop_j, crop_k = crop_fin.tolist() + crop_i, crop_j, crop_k = crop_fin.tolist() # type: ignore[misc] if crop_i: probability_map[-crop_i:, :, :] = 0 if crop_j: diff --git a/src/torchio/transforms/augmentation/intensity/random_motion.py b/src/torchio/transforms/augmentation/intensity/random_motion.py index edabc722..7411fe7c 100644 --- a/src/torchio/transforms/augmentation/intensity/random_motion.py +++ b/src/torchio/transforms/augmentation/intensity/random_motion.py @@ -287,7 +287,8 @@ def add_artifact( self.sort_spectra(spectra, times) result_spectrum = torch.empty_like(spectra[0]) last_index = result_spectrum.shape[2] - indices = (last_index * times).astype(int).tolist() + indices_array = (last_index * times).astype(int) + indices: list[int] = indices_array.tolist() # type: ignore[assignment] indices.append(last_index) ini = 0 for spectrum, fin in zip(spectra, indices): diff --git a/src/torchio/transforms/augmentation/intensity/random_swap.py b/src/torchio/transforms/augmentation/intensity/random_swap.py index b2076a68..874aa646 100644 --- a/src/torchio/transforms/augmentation/intensity/random_swap.py +++ b/src/torchio/transforms/augmentation/intensity/random_swap.py @@ -68,12 +68,12 @@ def get_params( for _ in range(num_iterations): first_ini, first_fin = get_random_indices_from_shape( spatial_shape, - patch_size.tolist(), + patch_size.tolist(), # type: ignore[arg-type] ) while True: second_ini, second_fin = get_random_indices_from_shape( spatial_shape, - patch_size.tolist(), + patch_size.tolist(), # type: ignore[arg-type] ) larger_than_initial = np.all(second_ini >= first_ini) less_than_final = np.all(second_fin <= first_fin) diff --git a/src/torchio/transforms/augmentation/spatial/random_affine.py b/src/torchio/transforms/augmentation/spatial/random_affine.py index 93797ea2..3a00a1ab 100644 --- a/src/torchio/transforms/augmentation/spatial/random_affine.py +++ b/src/torchio/transforms/augmentation/spatial/random_affine.py @@ -302,7 +302,7 @@ def ras_to_lps(triplet: Sequence[float]): radians = np.radians(degrees).tolist() # SimpleITK uses LPS - radians_lps = ras_to_lps(radians) + radians_lps = ras_to_lps(radians) # type: ignore[arg-type] translation_lps = ras_to_lps(translation) transform.SetRotation(*radians_lps) diff --git a/src/torchio/transforms/augmentation/spatial/random_flip.py b/src/torchio/transforms/augmentation/spatial/random_flip.py index e11882ed..0fcd6f23 100644 --- a/src/torchio/transforms/augmentation/spatial/random_flip.py +++ b/src/torchio/transforms/augmentation/spatial/random_flip.py @@ -53,11 +53,11 @@ def apply_transform(self, subject: Subject) -> Subject: if i not in potential_axes: axes_to_flip_hot[i] = False (axes,) = np.where(axes_to_flip_hot) - axes = axes.tolist() - if not axes: + axes_list = axes.tolist() + if not axes_list: return subject - arguments = {'axes': axes} + arguments = {'axes': axes_list} transform = Flip(**self.add_include_exclude(arguments)) transformed = transform(subject) assert isinstance(transformed, Subject)