Skip to content

Commit d36f0c8

Browse files
yiheng-wang-nvpre-commit-ci[bot]ericspodKumoLiu
authored
enable gpu load nifti (Project-MONAI#8188)
Related to Project-MONAI#8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <vennw@nvidia.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent efff647 commit d36f0c8

File tree

5 files changed

+136
-12
lines changed

5 files changed

+136
-12
lines changed

monai/data/image_reader.py

+77-9
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
from __future__ import annotations
1313

1414
import glob
15+
import gzip
16+
import io
1517
import os
1618
import re
19+
import tempfile
1720
import warnings
1821
from abc import ABC, abstractmethod
1922
from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -51,6 +54,9 @@
5154
pydicom, has_pydicom = optional_import("pydicom")
5255
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
5356

57+
cp, has_cp = optional_import("cupy")
58+
kvikio, has_kvikio = optional_import("kvikio")
59+
5460
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
5561

5662

@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
137143
)
138144

139145

140-
def _stack_images(image_list: list, meta_dict: dict):
146+
def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
141147
if len(image_list) <= 1:
142148
return image_list[0]
143149
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
144150
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
151+
if to_cupy and has_cp:
152+
return cp.concatenate(image_list, axis=channel_dim)
145153
return np.concatenate(image_list, axis=channel_dim)
146154
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
147155
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
156+
if to_cupy and has_cp:
157+
return cp.stack(image_list, axis=0)
148158
return np.stack(image_list, axis=0)
149159

150160

@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
864874
Load NIfTI format images based on Nibabel library.
865875
866876
Args:
867-
as_closest_canonical: if True, load the image as closest to canonical axis format.
868-
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
869877
channel_dim: the channel dimension of the input image, default is None.
870878
this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
871879
if None, `original_channel_dim` will be either `no_channel` or `-1`.
872880
most Nifti files are usually "channel last", no need to specify this argument for them.
881+
as_closest_canonical: if True, load the image as closest to canonical axis format.
882+
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
883+
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
884+
Default is False. CuPy and Kvikio are required for this option.
885+
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
886+
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887+
In practical use, it's recommended to add a warm up call before the actual loading.
888+
A related tutorial will be prepared in the future, and the document will be updated accordingly.
873889
kwargs: additional args for `nibabel.load` API. more details about available args:
874890
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
875891
@@ -880,14 +896,42 @@ def __init__(
880896
channel_dim: str | int | None = None,
881897
as_closest_canonical: bool = False,
882898
squeeze_non_spatial_dims: bool = False,
899+
to_gpu: bool = False,
883900
**kwargs,
884901
):
885902
super().__init__()
886903
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
887904
self.as_closest_canonical = as_closest_canonical
888905
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
906+
if to_gpu and (not has_cp or not has_kvikio):
907+
warnings.warn(
908+
"NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
909+
)
910+
to_gpu = False
911+
912+
if to_gpu:
913+
self.warmup_kvikio()
914+
915+
self.to_gpu = to_gpu
889916
self.kwargs = kwargs
890917

918+
def warmup_kvikio(self):
    """
    Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
    This can accelerate the data loading process when `to_gpu` is set to True.

    A small CuPy array is written to a temporary file and read back through
    `kvikio.CuFile`. Nothing is returned; the round trip exists only for its
    initialization side effects. This is a no-op when CuPy or Kvikio is unavailable.
    """
    if not (has_cp and has_kvikio):
        return

    a = cp.arange(100)
    with tempfile.NamedTemporaryFile() as tmp_file:
        tmp_file_name = tmp_file.name
        # use context managers so that both CuFile handles are always closed,
        # including the read handle and the case where write/read raises
        with kvikio.CuFile(tmp_file_name, "w") as f:
            f.write(a)

        b = cp.empty_like(a)
        with kvikio.CuFile(tmp_file_name, "r") as f:
            f.read(b)
934+
891935
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
892936
"""
893937
Verify whether the specified file or files format is supported by Nibabel reader.
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
916960
img_: list[Nifti1Image] = []
917961

918962
filenames: Sequence[PathLike] = ensure_tuple(data)
963+
self.filenames = filenames
919964
kwargs_ = self.kwargs.copy()
920965
kwargs_.update(kwargs)
921966
for name in filenames:
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
936981
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
937982
938983
"""
984+
# TODO: the actual type is list[np.ndarray | cp.ndarray]
985+
# need to find a way to annotate this type correctly without raising an error when cupy is not installed
986+
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
939987
img_array: list[np.ndarray] = []
940988
compatible_meta: dict = {}
941989

942-
for i in ensure_tuple(img):
990+
for i, filename in zip(ensure_tuple(img), self.filenames):
943991
header = self._get_meta_dict(i)
944992
header[MetaKeys.AFFINE] = self._get_affine(i)
945993
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
949997
header[MetaKeys.AFFINE] = self._get_affine(i)
950998
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
951999
header[MetaKeys.SPACE] = SpaceKeys.RAS
952-
data = self._get_array_data(i)
1000+
data = self._get_array_data(i, filename)
9531001
if self.squeeze_non_spatial_dims:
9541002
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
9551003
if data.shape[d - 1] == 1:
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
9631011
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
9641012
_copy_compatible_dict(header, compatible_meta)
9651013

966-
return _stack_images(img_array, compatible_meta), compatible_meta
1014+
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta
9671015

9681016
def _get_meta_dict(self, img) -> dict:
9691017
"""
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
10151063
spatial_rank = max(min(ndim, 3), 1)
10161064
return np.asarray(size[:spatial_rank])
10171065

1018-
def _get_array_data(self, img):
1066+
def _get_array_data(self, img, filename):
    """
    Get the raw array data of the image.

    When `self.to_gpu` is False the data is returned as a Numpy array (C order).
    When `self.to_gpu` is True the file is read with Kvikio and a CuPy array is returned.

    Args:
        img: a Nibabel image object loaded from an image file.
        filename: file name of the image, either a string or a path-like object.

    """
    if not self.to_gpu:
        return np.asanyarray(img.dataobj, order="C")

    file_size = os.path.getsize(filename)
    image = cp.empty(file_size, dtype=cp.uint8)
    with kvikio.CuFile(filename, "r") as f:
        f.read(image)

    # `filename` may be a `pathlib.Path` (no `endswith`), so normalise it to `str` first;
    # compare in lower case so that e.g. `.NII.GZ` is also recognised as compressed.
    if os.fsdecode(filename).lower().endswith(".nii.gz"):
        # for compressed data, have to transfer to CPU to decompress
        # and then transfer back to GPU. It is not efficient compared to .nii file
        # and may be slower than CPU loading in some cases.
        warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
        compressed_data = cp.asnumpy(image)
        with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
            decompressed_data = gz_file.read()

        image = cp.frombuffer(decompressed_data, dtype=cp.uint8)

    data_shape = img.shape
    data_offset = img.dataobj.offset
    data_dtype = img.dataobj.dtype
    # take exactly the bytes of the voxel array, so that any trailing bytes
    # after the image data do not break the `view`/`reshape` below
    data_nbytes = int(np.prod(data_shape)) * np.dtype(data_dtype).itemsize
    # NOTE(review): this returns the stored values as-is; it looks like header scaling
    # (slope/intercept), which the CPU path gets via `img.dataobj`, is not applied here - confirm.
    return image[data_offset : data_offset + data_nbytes].view(data_dtype).reshape(data_shape, order="F")
10271095

10281096

monai/data/meta_tensor.py

-1
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
553553
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
554554
"""
555555
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray
556-
557556
# if not tracking metadata, return `torch.Tensor`
558557
if not isinstance(img, MetaTensor):
559558
return img

monai/transforms/io/array.py

-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
286286
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
287287
f" The current registered: {self.readers}.\n{msg}"
288288
)
289-
290289
img_array: NdarrayOrTensor
291290
img_array, meta_data = reader.get_data(img)
292291
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]

tests/test_init_reader.py

+19
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ def test_load_image(self):
3030
inst = LoadImaged("image", reader=r)
3131
self.assertIsInstance(inst, LoadImaged)
3232

33+
@SkipIfNoModule("nibabel")
34+
@SkipIfNoModule("cupy")
35+
@SkipIfNoModule("kvikio")
36+
def test_load_image_to_gpu(self):
37+
for to_gpu in [True, False]:
38+
instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
39+
self.assertIsInstance(instance1, LoadImage)
40+
41+
instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
42+
self.assertIsInstance(instance2, LoadImaged)
43+
3344
@SkipIfNoModule("itk")
3445
@SkipIfNoModule("nibabel")
3546
@SkipIfNoModule("PIL")
@@ -58,6 +69,14 @@ def test_readers(self):
5869
inst = NrrdReader()
5970
self.assertIsInstance(inst, NrrdReader)
6071

72+
@SkipIfNoModule("nibabel")
73+
@SkipIfNoModule("cupy")
74+
@SkipIfNoModule("kvikio")
75+
def test_readers_to_gpu(self):
76+
for to_gpu in [True, False]:
77+
inst = NibabelReader(to_gpu=to_gpu)
78+
self.assertIsInstance(inst, NibabelReader)
79+
6180

6281
if __name__ == "__main__":
6382
unittest.main()

tests/test_load_image.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from monai.data.meta_tensor import MetaTensor
3030
from monai.transforms import LoadImage
3131
from monai.utils import optional_import
32-
from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config
32+
from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config
3333

3434
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
3535
ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
@@ -74,6 +74,22 @@ def get_data(self, _obj):
7474

7575
TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]
7676

77+
TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]
78+
79+
TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]
80+
81+
TEST_CASE_GPU_3 = [
82+
{"reader": "nibabelreader", "to_gpu": True},
83+
["test_image.nii", "test_image2.nii", "test_image3.nii"],
84+
(3, 128, 128, 128),
85+
]
86+
87+
TEST_CASE_GPU_4 = [
88+
{"reader": "nibabelreader", "to_gpu": True},
89+
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
90+
(3, 128, 128, 128),
91+
]
92+
7793
TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
7894

7995
TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
@@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
196212
assert_allclose(result.affine, torch.eye(4))
197213
self.assertTupleEqual(result.shape, expected_shape)
198214

215+
@SkipIfNoModule("nibabel")
216+
@SkipIfNoModule("cupy")
217+
@SkipIfNoModule("kvikio")
218+
@parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
219+
def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
220+
test_image = np.random.rand(128, 128, 128)
221+
with tempfile.TemporaryDirectory() as tempdir:
222+
for i, name in enumerate(filenames):
223+
filenames[i] = os.path.join(tempdir, name)
224+
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
225+
result = LoadImage(image_only=True, **input_param)(filenames)
226+
ext = "".join(Path(name).suffixes)
227+
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
228+
self.assertEqual(result.meta["space"], "RAS")
229+
assert_allclose(result.affine, torch.eye(4))
230+
self.assertTupleEqual(result.shape, expected_shape)
231+
232+
# verify gpu and cpu loaded data are the same
233+
input_param_cpu = input_param.copy()
234+
input_param_cpu["to_gpu"] = False
235+
result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
236+
self.assertTrue(torch.equal(result_cpu, result.cpu()))
237+
199238
@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
200239
def test_itk_reader(self, input_param, filenames, expected_shape):
201240
test_image = np.random.rand(128, 128, 128)

0 commit comments

Comments
 (0)