Skip to content

Commit

Permalink
[utils] Implicitly convert unsigned types which are not supported by …
Browse files Browse the repository at this point in the history
…PyTorch
  • Loading branch information
aschuh-hf committed Aug 3, 2023
1 parent 28ccfa9 commit 03d763f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/deepali/utils/imageio/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def read_meta_image(arg: Union[PathUri, bytes, io.BufferedReader]) -> Tuple[Tens
spacing=spacing,
direction=matrix,
)
if data.dtype == np.uint16:
data = data.astype(np.int32)
elif data.dtype == np.uint32:
data = data.astype(np.int64)
return torch.from_numpy(data), grid


Expand Down
4 changes: 4 additions & 0 deletions src/deepali/utils/imageio/nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def read_nifti_image(path: PathUri) -> Tuple[Tensor, Grid]:
grid = Grid(size=size, origin=origin, spacing=spacing, direction=direction)
if data.ndim == grid.ndim:
data = np.expand_dims(data, 0)
if data.dtype == np.uint16:
data = data.astype(np.int32)
elif data.dtype == np.uint32:
data = data.astype(np.int64)
return torch.from_numpy(data), grid


Expand Down
2 changes: 1 addition & 1 deletion src/deepali/utils/simpleitk/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def tensor_from_image(
r"""Create image data tensor from ``SimpleITK.Image``."""
if image.GetPixelID() == sitk.sitkUInt16:
image = sitk.Cast(image, sitk.sitkInt32)
elif image.GetPixelID() == sitk.sitkUInt16:
elif image.GetPixelID() == sitk.sitkUInt32:
image = sitk.Cast(image, sitk.sitkInt64)
data = torch.from_numpy(sitk.GetArrayFromImage(image))
data = data.unsqueeze(0)
Expand Down

0 comments on commit 03d763f

Please sign in to comment.