from __future__ import annotations

import glob
+ import gzip
+ import io
import os
import re
+ import tempfile
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence

pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

+ cp, has_cp = optional_import("cupy")
+ kvikio, has_kvikio = optional_import("kvikio")
+
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]

@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
            )


- def _stack_images(image_list: list, meta_dict: dict):
+ def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
    if len(image_list) <= 1:
        return image_list[0]
    if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
        channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
+         if to_cupy and has_cp:
+             return cp.concatenate(image_list, axis=channel_dim)
        return np.concatenate(image_list, axis=channel_dim)
    # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
    meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
+     if to_cupy and has_cp:
+         return cp.stack(image_list, axis=0)
    return np.stack(image_list, axis=0)
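
As a side note for reviewers, here is a minimal standalone sketch of the dispatch the modified `_stack_images` performs; `stack_slices` and its arguments are illustrative names for this sketch only, not part of the MONAI API:

import numpy as np

try:
    import cupy as cp
    has_cp = True
except ImportError:
    has_cp = False

def stack_slices(image_list, channel_dim=None, to_cupy=False):
    # Mirror of the dispatch above: concatenate along an existing channel dim when one
    # is recorded in the metadata, otherwise stack along a new leading channel dim.
    # With to_cupy=True (and CuPy available) the result stays on the GPU.
    xp = cp if (to_cupy and has_cp) else np
    if len(image_list) <= 1:
        return image_list[0]
    if channel_dim is not None:
        return xp.concatenate(image_list, axis=int(channel_dim))
    return xp.stack(image_list, axis=0)

# e.g. stack_slices([np.zeros((4, 4)), np.ones((4, 4))]).shape == (2, 4, 4)
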
@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
    Load NIfTI format images based on Nibabel library.

    Args:
-         as_closest_canonical: if True, load the image as closest to canonical axis format.
-         squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
        channel_dim: the channel dimension of the input image, default is None.
            this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
            if None, `original_channel_dim` will be either `no_channel` or `-1`.
            most Nifti files are usually "channel last", no need to specify this argument for them.
+         as_closest_canonical: if True, load the image as closest to canonical axis format.
+         squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
+         to_gpu: if True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
+             Default is False. CuPy and Kvikio are required for this option.
+             Note: for compressed NIfTI files, some operations may still be performed in CPU memory,
+             so the acceleration may not be significant, and in some cases loading may be slower than on CPU.
+             In practice, it is recommended to add a warm-up call before the actual loading.
+             A related tutorial will be prepared in the future, and the documentation will be updated accordingly.
        kwargs: additional args for `nibabel.load` API. more details about available args:
            https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py

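To make the intended usage concrete, a hedged sketch of how the new option would be exercised through the reader API (the file path is a placeholder, and the GPU path only takes effect when both CuPy and Kvikio are installed):

from monai.data import NibabelReader

# With to_gpu=True the reader loads voxel data through Kvikio/CuPy; the constructor
# falls back to CPU loading (with a warning) when CuPy or Kvikio is unavailable.
reader = NibabelReader(to_gpu=True)

imgs = reader.read("/path/to/volume.nii")  # placeholder path; uncompressed .nii benefits most
data, meta = reader.get_data(imgs)         # `data` is a CuPy array when the GPU path is active
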
@@ -880,14 +896,42 @@ def __init__(
        channel_dim: str | int | None = None,
        as_closest_canonical: bool = False,
        squeeze_non_spatial_dims: bool = False,
+         to_gpu: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
        self.as_closest_canonical = as_closest_canonical
        self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
+         if to_gpu and (not has_cp or not has_kvikio):
+             warnings.warn(
+                 "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
+             )
+             to_gpu = False
+
+         if to_gpu:
+             self.warmup_kvikio()
+
+         self.to_gpu = to_gpu
        self.kwargs = kwargs

+     def warmup_kvikio(self):
+         """
+         Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
+         This can accelerate the data loading process when `to_gpu` is set to True.
+         """
+         if has_cp and has_kvikio:
+             a = cp.arange(100)
+             with tempfile.NamedTemporaryFile() as tmp_file:
+                 tmp_file_name = tmp_file.name
+                 f = kvikio.CuFile(tmp_file_name, "w")
+                 f.write(a)
+                 f.close()
+
+                 b = cp.empty_like(a)
+                 f = kvikio.CuFile(tmp_file_name, "r")
+                 f.read(b)
+
    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
        """
        Verify whether the specified file or files format is supported by Nibabel reader.
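
For reviewers unfamiliar with Kvikio, a small standalone script along these lines shows why the warm-up is worthwhile; the timing numbers are machine-dependent and the snippet assumes a GPU environment with CuPy and Kvikio installed:

import tempfile
import time

import cupy as cp
import kvikio

# The first CuFile round-trip absorbs cuFile/GDS initialization cost; later
# round-trips are much cheaper. This is the cost warmup_kvikio() pays up front
# so that it does not land on the first real image load.
a = cp.arange(100)
for attempt in range(2):
    start = time.perf_counter()
    with tempfile.NamedTemporaryFile() as tmp:
        f = kvikio.CuFile(tmp.name, "w")
        f.write(a)
        f.close()
        b = cp.empty_like(a)
        f = kvikio.CuFile(tmp.name, "r")
        f.read(b)
        f.close()
    print(f"round-trip {attempt}: {time.perf_counter() - start:.4f} s")
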
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
        img_: list[Nifti1Image] = []

        filenames: Sequence[PathLike] = ensure_tuple(data)
+         self.filenames = filenames
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for name in filenames:
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
            img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

        """
+         # TODO: the actual type is list[np.ndarray | cp.ndarray]
+         # should figure out how to define correct types without having cupy not found error
+         # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
        img_array: list[np.ndarray] = []
        compatible_meta: dict = {}

-         for i in ensure_tuple(img):
+         for i, filename in zip(ensure_tuple(img), self.filenames):
            header = self._get_meta_dict(i)
            header[MetaKeys.AFFINE] = self._get_affine(i)
            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                header[MetaKeys.AFFINE] = self._get_affine(i)
            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
            header[MetaKeys.SPACE] = SpaceKeys.RAS
-             data = self._get_array_data(i)
+             data = self._get_array_data(i, filename)
            if self.squeeze_non_spatial_dims:
                for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
                    if data.shape[d - 1] == 1:
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
            _copy_compatible_dict(header, compatible_meta)

-         return _stack_images(img_array, compatible_meta), compatible_meta
+         return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

    def _get_meta_dict(self, img) -> dict:
        """
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
        spatial_rank = max(min(ndim, 3), 1)
        return np.asarray(size[:spatial_rank])

-     def _get_array_data(self, img):
+     def _get_array_data(self, img, filename):
        """
        Get the raw array data of the image, converted to Numpy array.

        Args:
            img: a Nibabel image object loaded from an image file.
-
-         """
+             filename: file name of the image.
+
+         """
+         if self.to_gpu:
+             file_size = os.path.getsize(filename)
+             image = cp.empty(file_size, dtype=cp.uint8)
+             with kvikio.CuFile(filename, "r") as f:
+                 f.read(image)
+             if filename.endswith(".nii.gz"):
+                 # for compressed data, we have to transfer the bytes to CPU to decompress
+                 # and then transfer them back to GPU. This is less efficient than an
+                 # uncompressed .nii file and may be slower than CPU loading in some cases.
+                 warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
+                 compressed_data = cp.asnumpy(image)
+                 with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
+                     decompressed_data = gz_file.read()
+
+                 image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
+             data_shape = img.shape
+             data_offset = img.dataobj.offset
+             data_dtype = img.dataobj.dtype
+             return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
        return np.asanyarray(img.dataobj, order="C")