diff --git a/src/deepali/utils/imageio/meta.py b/src/deepali/utils/imageio/meta.py index 7168932..260874c 100644 --- a/src/deepali/utils/imageio/meta.py +++ b/src/deepali/utils/imageio/meta.py @@ -64,6 +64,10 @@ def read_meta_image(arg: Union[PathUri, bytes, io.BufferedReader]) -> Tuple[Tens spacing=spacing, direction=matrix, ) + if data.dtype == np.uint16: + data = data.astype(np.int32) + elif data.dtype == np.uint32: + data = data.astype(np.int64) return torch.from_numpy(data), grid diff --git a/src/deepali/utils/imageio/nifti.py b/src/deepali/utils/imageio/nifti.py index d793e4e..f821809 100644 --- a/src/deepali/utils/imageio/nifti.py +++ b/src/deepali/utils/imageio/nifti.py @@ -88,6 +88,10 @@ def read_nifti_image(path: PathUri) -> Tuple[Tensor, Grid]: grid = Grid(size=size, origin=origin, spacing=spacing, direction=direction) if data.ndim == grid.ndim: data = np.expand_dims(data, 0) + if data.dtype == np.uint16: + data = data.astype(np.int32) + elif data.dtype == np.uint32: + data = data.astype(np.int64) return torch.from_numpy(data), grid diff --git a/src/deepali/utils/simpleitk/torch.py b/src/deepali/utils/simpleitk/torch.py index cda93bd..c5df5b7 100644 --- a/src/deepali/utils/simpleitk/torch.py +++ b/src/deepali/utils/simpleitk/torch.py @@ -48,7 +48,7 @@ def tensor_from_image( r"""Create image data tensor from ``SimpleITK.Image``.""" if image.GetPixelID() == sitk.sitkUInt16: image = sitk.Cast(image, sitk.sitkInt32) - elif image.GetPixelID() == sitk.sitkUInt16: + elif image.GetPixelID() == sitk.sitkUInt32: image = sitk.Cast(image, sitk.sitkInt64) data = torch.from_numpy(sitk.GetArrayFromImage(image)) data = data.unsqueeze(0)