Skip to content

Commit

Permalink
Address some typing errors (#1253)
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar authored Dec 19, 2024
1 parent 6596899 commit 795eb1f
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ def get_bounds(self) -> TypeBounds:
first_point = apply_affine(self.affine, first_index)
last_point = apply_affine(self.affine, last_index)
array = np.array((first_point, last_point))
bounds_x, bounds_y, bounds_z = array.T.tolist()
return bounds_x, bounds_y, bounds_z
bounds_x, bounds_y, bounds_z = array.T.tolist() # type: ignore[misc]
return bounds_x, bounds_y, bounds_z # type: ignore[return-value]

@staticmethod
def _parse_single_path(
Expand Down
30 changes: 17 additions & 13 deletions src/torchio/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import nibabel as nib
import numpy as np
import numpy.typing as npt
import SimpleITK as sitk
import torch
from nibabel.filebasedimages import ImageFileError
from nibabel.spatialimages import SpatialImage

from ..constants import REPO_URL
Expand Down Expand Up @@ -36,7 +38,7 @@ def read_image(path: TypePath) -> TypeDataAffine:
warnings.warn(message, stacklevel=2)
try:
result = _read_nibabel(path)
except nib.loadsave.ImageFileError as e:
except ImageFileError as e:
message = (
f'File "{path}" not understood.'
' Check supported formats by at'
Expand All @@ -56,6 +58,7 @@ def _read_nibabel(path: TypePath) -> TypeDataAffine:
data = check_uint_to_int(data)
tensor = torch.as_tensor(data)
affine = img.affine
assert isinstance(affine, np.ndarray)
return tensor, affine


Expand Down Expand Up @@ -155,19 +158,19 @@ def _write_nibabel(
else:
tensor = tensor[np.newaxis].permute(2, 3, 4, 0, 1)
suffix = Path(str(path).replace('.gz', '')).suffix
img: Union[nib.Nifti1Image, nib.Nifti1Pair]
img: Union[nib.nifti1.Nifti1Image, nib.nifti1.Nifti1Pair]
if '.nii' in suffix:
img = nib.Nifti1Image(np.asarray(tensor), affine)
img = nib.nifti1.Nifti1Image(np.asarray(tensor), affine)
elif '.hdr' in suffix or '.img' in suffix:
img = nib.Nifti1Pair(np.asarray(tensor), affine)
img = nib.nifti1.Nifti1Pair(np.asarray(tensor), affine)
else:
raise nib.loadsave.ImageFileError
assert isinstance(img.header, nib.Nifti1Header)
raise ImageFileError
assert isinstance(img.header, nib.nifti1.Nifti1Header)
if num_components > 1:
img.header.set_intent('vector')
img.header['qform_code'] = 1
img.header['sform_code'] = 0
nib.save(img, str(path))
nib.loadsave.save(img, str(path))


def _write_sitk(
Expand Down Expand Up @@ -269,9 +272,9 @@ def _matrix_to_itk_transform(

def _read_niftyreg_matrix(trsf_path: TypePath) -> torch.Tensor:
"""Read a NiftyReg matrix and return it as a NumPy array."""
matrix = np.loadtxt(trsf_path)
matrix = np.linalg.inv(matrix)
return torch.as_tensor(matrix)
read_matrix = np.loadtxt(trsf_path).astype(np.float64)
inverted = np.linalg.inv(read_matrix)
return torch.from_numpy(inverted)


def _write_niftyreg_matrix(matrix: TypeData, txt_path: TypePath) -> None:
Expand Down Expand Up @@ -361,10 +364,11 @@ def sitk_to_nib(
def get_ras_affine_from_sitk(
sitk_object: Union[sitk.Image, sitk.ImageFileReader],
) -> np.ndarray:
spacing = np.array(sitk_object.GetSpacing())
direction_lps = np.array(sitk_object.GetDirection())
origin_lps = np.array(sitk_object.GetOrigin())
spacing = np.array(sitk_object.GetSpacing(), dtype=np.float64)
direction_lps = np.array(sitk_object.GetDirection(), dtype=np.float64)
origin_lps = np.array(sitk_object.GetOrigin(), dtype=np.float64)
direction_length = len(direction_lps)
rotation_lps: npt.NDArray[np.float64]
if direction_length == 9:
rotation_lps = direction_lps.reshape(3, 3)
elif direction_length == 4: # ignore last dimension if 2D (1, W, H, 1)
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/data/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_crop_transform(
crop_ini = index_ini_array.tolist()
crop_fin = (shape - index_fin).tolist()
start = ()
cropping = sum(zip(crop_ini, crop_fin), start)
cropping = sum(zip(crop_ini, crop_fin), start) # type: ignore[arg-type]
return Crop(cropping) # type: ignore[arg-type]

def __call__(
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/data/sampler/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def clear_probability_borders(

# The call tolist() is very important. Using np.uint16 as negative
# index will not work because e.g. -np.uint16(2) == 65534
crop_i, crop_j, crop_k = crop_fin.tolist()
crop_i, crop_j, crop_k = crop_fin.tolist() # type: ignore[misc]
if crop_i:
probability_map[-crop_i:, :, :] = 0
if crop_j:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def add_artifact(
self.sort_spectra(spectra, times)
result_spectrum = torch.empty_like(spectra[0])
last_index = result_spectrum.shape[2]
indices = (last_index * times).astype(int).tolist()
indices_array = (last_index * times).astype(int)
indices: list[int] = indices_array.tolist() # type: ignore[assignment]
indices.append(last_index)
ini = 0
for spectrum, fin in zip(spectra, indices):
Expand Down
4 changes: 2 additions & 2 deletions src/torchio/transforms/augmentation/intensity/random_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def get_params(
for _ in range(num_iterations):
first_ini, first_fin = get_random_indices_from_shape(
spatial_shape,
patch_size.tolist(),
patch_size.tolist(), # type: ignore[arg-type]
)
while True:
second_ini, second_fin = get_random_indices_from_shape(
spatial_shape,
patch_size.tolist(),
patch_size.tolist(), # type: ignore[arg-type]
)
larger_than_initial = np.all(second_ini >= first_ini)
less_than_final = np.all(second_fin <= first_fin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def ras_to_lps(triplet: Sequence[float]):
radians = np.radians(degrees).tolist()

# SimpleITK uses LPS
radians_lps = ras_to_lps(radians)
radians_lps = ras_to_lps(radians) # type: ignore[arg-type]
translation_lps = ras_to_lps(translation)

transform.SetRotation(*radians_lps)
Expand Down
6 changes: 3 additions & 3 deletions src/torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def apply_transform(self, subject: Subject) -> Subject:
if i not in potential_axes:
axes_to_flip_hot[i] = False
(axes,) = np.where(axes_to_flip_hot)
axes = axes.tolist()
if not axes:
axes_list = axes.tolist()
if not axes_list:
return subject

arguments = {'axes': axes}
arguments = {'axes': axes_list}
transform = Flip(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
Expand Down

0 comments on commit 795eb1f

Please sign in to comment.