Skip to content

Commit e13e756

Browse files
authored
Add support to slice images and subjects (#1170)
* Add support to slice images and subjects * Update docs * Address issues * Ignore mypy errors
1 parent 4a06261 commit e13e756

File tree

7 files changed

+99
-2
lines changed

7 files changed

+99
-2
lines changed

docs/source/data/image.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ instantiation time.
2222
Instead, the data is only loaded when needed for an operation
2323
(e.g., if a transform is applied to the image).
2424

25+
Images can be sliced using the standard NumPy / PyTorch slicing syntax.
26+
This operation updates the coordinates origin in the affine matrix
27+
correspondingly.
28+
2529
The figure below shows two instances of :class:`Image`.
2630
The instance of :class:`ScalarImage` contains a 4D tensor representing a
2731
diffusion MRI, which contains four 3D volumes (one per gradient direction),

docs/source/data/subject.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ The :class:`Subject` is a data structure used to store
88
images associated with a subject and any other metadata necessary for
99
processing.
1010

11+
Subject objects can be sliced using the standard NumPy / PyTorch slicing
12+
syntax, returning a new subject with sliced images.
13+
This is only possible if all images in the subject have the same spatial
14+
shape.
15+
1116
All transforms applied to a :class:`Subject` are saved
1217
in its :attr:`history` attribute (see :ref:`Reproducibility`).
1318

src/torchio/data/image.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
from ..typing import TypeDirection3D
3131
from ..typing import TypePath
3232
from ..typing import TypeQuartetInt
33+
from ..typing import TypeSlice
3334
from ..typing import TypeTripletFloat
3435
from ..typing import TypeTripletInt
3536
from ..utils import get_stem
37+
from ..utils import to_tuple
3638
from ..utils import guess_external_viewer
3739
from ..utils import is_iterable
3840
from .io import check_uint_to_int
@@ -201,6 +203,9 @@ def __repr__(self):
201203
return string
202204

203205
def __getitem__(self, item):
206+
if isinstance(item, (slice, int, tuple)):
207+
return self._crop_from_slices(item)
208+
204209
if item in (DATA, AFFINE):
205210
if item not in self:
206211
self.load()
@@ -784,6 +789,44 @@ def show(self, viewer_path: Optional[TypePath] = None) -> None:
784789
image_viewer.SetApplication(str(viewer_path))
785790
image_viewer.Execute(sitk_image)
786791

792+
def _crop_from_slices(
793+
self,
794+
slices: Union[TypeSlice, Tuple[TypeSlice, ...]],
795+
) -> 'Image':
796+
from ..transforms import Crop
797+
798+
slices_tuple = to_tuple(slices) # type: ignore[assignment]
799+
cropping: List[int] = []
800+
for dim, slice_ in enumerate(slices_tuple):
801+
if isinstance(slice_, slice):
802+
pass
803+
elif slice_ is Ellipsis:
804+
message = 'Ellipsis slicing is not supported yet'
805+
raise NotImplementedError(message)
806+
elif isinstance(slice_, int):
807+
slice_ = slice(slice_, slice_ + 1) # type: ignore[assignment]
808+
else:
809+
message = f'Slice type not understood: "{type(slice_)}"'
810+
raise TypeError(message)
811+
shape_dim = self.spatial_shape[dim]
812+
assert isinstance(slice_, slice)
813+
start, stop, step = slice_.indices(shape_dim)
814+
if step != 1:
815+
message = (
816+
'Slicing with steps different from 1 is not supported yet.'
817+
' Use the Crop transform instead'
818+
)
819+
raise ValueError(message)
820+
crop_ini = start
821+
crop_fin = shape_dim - stop
822+
cropping.extend([crop_ini, crop_fin])
823+
while dim < 2:
824+
cropping.extend([0, 0])
825+
dim += 1
826+
w_ini, w_fin, h_ini, h_fin, d_ini, d_fin = cropping
827+
cropping_arg = w_ini, w_fin, h_ini, h_fin, d_ini, d_fin # making mypy happy
828+
return Crop(cropping_arg)(self) # type: ignore[return-value]
829+
787830

788831
class ScalarImage(Image):
789832
"""Image whose pixel values represent scalars.

src/torchio/data/subject.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,23 @@ def __copy__(self):
7777
def __len__(self):
7878
return len(self.get_images(intensity_only=False))
7979

80+
def __getitem__(self, item):
81+
if isinstance(item, (slice, int, tuple)):
82+
try:
83+
self.check_consistent_spatial_shape()
84+
except RuntimeError as e:
85+
message = (
86+
'To use indexing, all images in the subject must have the'
87+
' same spatial shape'
88+
)
89+
raise RuntimeError(message) from e
90+
copied = copy.deepcopy(self)
91+
for image_name, image in copied.items():
92+
copied[image_name] = image[item]
93+
return copied
94+
else:
95+
return super().__getitem__(item)
96+
8097
@staticmethod
8198
def _parse_images(images: List[Image]) -> None:
8299
# Check that it's not empty

src/torchio/transforms/preprocessing/spatial/resample.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,17 @@ def apply_transform(self, subject: Subject) -> Subject:
158158
self.check_affine_key_presence(self.pre_affine_name, subject)
159159

160160
for image in self.get_images(subject):
161-
# Do not resample the reference image if it is in the subject
161+
# If the current image is the reference, don't resample it
162162
if self.target is image:
163163
continue
164+
165+
# If the target is not a string, or is not an image in the subject,
166+
# do nothing
164167
try:
165168
target_image = subject[self.target]
166169
if target_image is image:
167170
continue
168-
except (KeyError, TypeError):
171+
except (KeyError, TypeError, RuntimeError):
169172
pass
170173

171174
# Choose interpolation

src/torchio/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TypeKeys = Optional[Sequence[str]]
1616
TypeData = Union[torch.Tensor, np.ndarray]
1717
TypeDataAffine = Tuple[torch.Tensor, np.ndarray]
18+
TypeSlice = Union[int, slice]
1819

1920
TypeDoubletInt = Tuple[int, int]
2021
TypeTripletInt = Tuple[int, int, int]

tests/data/test_image.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,27 @@ def test_copy_no_data(self):
283283
new_image = copy.copy(my_image)
284284
assert my_image._loaded
285285
assert new_image._loaded
286+
287+
def test_slicing(self):
288+
path = self.get_image_path('im_slicing')
289+
image = tio.ScalarImage(path)
290+
291+
assert image.shape == (1, 10, 20, 30)
292+
293+
cropped = image[0]
294+
assert cropped.shape == (1, 1, 20, 30)
295+
296+
cropped = image[:, 2:-3]
297+
assert cropped.shape == (1, 10, 15, 30)
298+
299+
cropped = image[-5:, 5:]
300+
assert cropped.shape == (1, 5, 15, 30)
301+
302+
with pytest.raises(NotImplementedError):
303+
image[..., 5]
304+
305+
with pytest.raises(ValueError):
306+
image[0:8:-1]
307+
308+
with pytest.raises(ValueError):
309+
image[3::-1]

0 commit comments

Comments
 (0)