diff --git a/.all-contributorsrc b/.all-contributorsrc
index 92832ff8..2a6692b9 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -815,6 +815,15 @@
"contributions": [
"bug"
]
+ },
+ {
+ "login": "rickymwalsh",
+ "name": "Ricky Walsh",
+ "avatar_url": "https://avatars.githubusercontent.com/u/70853488?v=4",
+ "profile": "https://www.linkedin.com/in/ricky-walsh/",
+ "contributions": [
+ "code"
+ ]
}
],
"contributorsPerLine": 7,
diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index e98d531a..108a32a5 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 0.19.9
+current_version = 0.20.0
commit = True
tag = True
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 145e9d1e..45eaa58f 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -40,8 +40,6 @@ jobs:
- name: Publish package to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
- user: __token__
- password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/
verbose: true
skip-existing: true
@@ -51,5 +49,4 @@ jobs:
if: startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
- password: ${{ secrets.PYPI_API_TOKEN }}
verbose: true
diff --git a/README.md b/README.md
index bbb256af..0763c9fc 100644
--- a/README.md
+++ b/README.md
@@ -408,6 +408,7 @@ Thanks goes to all these people ([emoji key](https://allcontributors.org/docs/en
marius-sm 🤔 |
haarisr 💻 |
Chris Winder 🐛 |
+ Ricky Walsh 💻 |
diff --git a/docs/examples/plot_3d_to_2d.py b/docs/examples/plot_3d_to_2d.py
index 7046b8b7..fb26c78e 100644
--- a/docs/examples/plot_3d_to_2d.py
+++ b/docs/examples/plot_3d_to_2d.py
@@ -32,7 +32,7 @@
def plot_batch(sampler):
queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler)
- loader = torch.utils.data.DataLoader(queue, batch_size=16)
+ loader = tio.SubjectsLoader(queue, batch_size=16)
batch = tio.utils.get_first_item(loader)
fig, axes = plt.subplots(4, 4, figsize=(12, 10))
diff --git a/docs/examples/plot_history.py b/docs/examples/plot_history.py
index 0b39f8a7..3af4424a 100644
--- a/docs/examples/plot_history.py
+++ b/docs/examples/plot_history.py
@@ -43,7 +43,7 @@
) # noqa: T201, B950
print(transformed.get_inverse_transform(ignore_intensity=False)) # noqa: T201
-loader = torch.utils.data.DataLoader(
+loader = tio.SubjectsLoader(
dataset,
batch_size=batch_size,
collate_fn=tio.utils.history_collate,
diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index fa0a39ed..f5bf52fe 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -206,7 +206,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.OrganMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -224,7 +224,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.NoduleMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -242,7 +242,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.AdrenalMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -260,7 +260,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.FractureMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -278,7 +278,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.VesselMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
@@ -296,7 +296,7 @@ MedMNIST
from einops import rearrange
rows, cols = 16, 28
dataset = tio.datasets.SynapseMNIST3D('train')
- loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols)
+ loader = tio.SubjectsLoader(dataset, batch_size=rows * cols)
batch = tio.utils.get_first_item(loader)
tensor = batch['image'][tio.DATA]
pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)'
diff --git a/docs/source/patches/patch_inference.rst b/docs/source/patches/patch_inference.rst
index 836ad512..98e1d77e 100644
--- a/docs/source/patches/patch_inference.rst
+++ b/docs/source/patches/patch_inference.rst
@@ -17,7 +17,7 @@ inference across a 3D image using patches::
... patch_size,
... patch_overlap,
... )
- >>> patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
+ >>> patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4)
>>> aggregator = tio.inference.GridAggregator(grid_sampler)
>>> model = nn.Identity().eval()
>>> with torch.no_grad():
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index e09bc816..e84fd6e3 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -37,13 +37,12 @@ Hello, World!
This example shows the basic usage of TorchIO, where an instance of
:class:`~torchio.SubjectsDataset` is passed to
-a PyTorch :class:`~torch.utils.data.DataLoader` to generate training batches
+a TorchIO :class:`~torchio.SubjectsLoader` to generate training batches
of 3D images that are loaded, preprocessed and augmented on the fly,
in parallel::
import torch
import torchio as tio
- from torch.utils.data import DataLoader
# Each instance of tio.Subject is passed arbitrary keyword arguments.
# Typically, these arguments will be instances of tio.Image
@@ -91,8 +90,14 @@ in parallel::
# SubjectsDataset is a subclass of torch.data.utils.Dataset
subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
- # Images are processed in parallel thanks to a PyTorch DataLoader
- training_loader = DataLoader(subjects_dataset, batch_size=4, num_workers=4)
+ # Images are processed in parallel thanks to a SubjectsLoader
+ # (which inherits from torch.utils.data.DataLoader)
+ training_loader = tio.SubjectsLoader(
+ subjects_dataset,
+ batch_size=4,
+ num_workers=4,
+ shuffle=True,
+ )
# Training epoch
for subjects_batch in training_loader:
diff --git a/setup.cfg b/setup.cfg
index 4e54a9de..847c26a8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = torchio
-version = 0.19.9
+version = 0.20.0
description = Tools for medical image processing with PyTorch
long_description = file: README.md
long_description_content_type = text/markdown
diff --git a/src/torchio/__init__.py b/src/torchio/__init__.py
index 1a1f0403..85de897c 100644
--- a/src/torchio/__init__.py
+++ b/src/torchio/__init__.py
@@ -2,7 +2,7 @@
__author__ = """Fernando Perez-Garcia"""
__email__ = 'fepegar@gmail.com'
-__version__ = '0.19.9'
+__version__ = '0.20.0'
from . import utils
@@ -13,6 +13,7 @@
sampler,
inference,
SubjectsDataset,
+ SubjectsLoader,
Image,
ScalarImage,
LabelMap,
@@ -34,6 +35,7 @@
'sampler',
'inference',
'SubjectsDataset',
+ 'SubjectsLoader',
'Image',
'ScalarImage',
'LabelMap',
diff --git a/src/torchio/data/__init__.py b/src/torchio/data/__init__.py
index f97c42e8..7ebc6724 100644
--- a/src/torchio/data/__init__.py
+++ b/src/torchio/data/__init__.py
@@ -3,6 +3,7 @@
from .image import LabelMap
from .image import ScalarImage
from .inference import GridAggregator
+from .loader import SubjectsLoader
from .queue import Queue
from .sampler import GridSampler
from .sampler import LabelSampler
@@ -16,6 +17,7 @@
'Queue',
'Subject',
'SubjectsDataset',
+ 'SubjectsLoader',
'Image',
'ScalarImage',
'LabelMap',
diff --git a/src/torchio/data/dataset.py b/src/torchio/data/dataset.py
index 4a0e4c31..d1bdfc23 100644
--- a/src/torchio/data/dataset.py
+++ b/src/torchio/data/dataset.py
@@ -17,8 +17,8 @@ class SubjectsDataset(Dataset):
"""Base TorchIO dataset.
Reader of 3D medical images that directly inherits from the PyTorch
- :class:`~torch.utils.data.Dataset`. It can be used with a PyTorch
- :class:`~torch.utils.data.DataLoader` for efficient loading and
+ :class:`~torch.utils.data.Dataset`. It can be used with a
+ :class:`~torchio.SubjectsLoader` for efficient loading and
augmentation. It receives a list of instances of :class:`~torchio.Subject`
and an optional transform applied to the volumes after loading.
diff --git a/src/torchio/data/inference/aggregator.py b/src/torchio/data/inference/aggregator.py
index c1c0ae9b..52985cad 100644
--- a/src/torchio/data/inference/aggregator.py
+++ b/src/torchio/data/inference/aggregator.py
@@ -139,8 +139,8 @@ def add_batch(
extracted using ``batch[torchio.LOCATION]``.
"""
batch = batch_tensor.cpu()
- locations = locations.cpu().numpy()
- patch_sizes = locations[:, 3:] - locations[:, :3]
+ locations_array = locations.cpu().numpy()
+ patch_sizes = locations_array[:, 3:] - locations_array[:, :3]
# There should be only one patch size
assert len(np.unique(patch_sizes, axis=0)) == 1
input_spatial_shape = tuple(batch.shape[-3:])
@@ -155,7 +155,7 @@ def add_batch(
self._initialize_output_tensor(batch)
assert isinstance(self._output_tensor, torch.Tensor)
if self.overlap_mode == 'crop':
- for patch, location in zip(batch, locations):
+ for patch, location in zip(batch, locations_array):
cropped_patch, new_location = self._crop_patch(
patch,
location,
diff --git a/src/torchio/data/io.py b/src/torchio/data/io.py
index dd9704ab..d463f492 100644
--- a/src/torchio/data/io.py
+++ b/src/torchio/data/io.py
@@ -181,18 +181,19 @@ def _write_sitk(
) -> None:
assert tensor.ndim == 4
path = Path(path)
+ array = tensor.numpy()
if path.suffix in ('.png', '.jpg', '.jpeg', '.bmp'):
warnings.warn(
f'Casting to uint 8 before saving to {path}',
RuntimeWarning,
stacklevel=2,
)
- tensor = tensor.numpy().astype(np.uint8)
+ array = array.astype(np.uint8)
if squeeze is None:
force_3d = path.suffix not in IMAGE_2D_FORMATS
else:
force_3d = not squeeze
- image = nib_to_sitk(tensor, affine, force_3d=force_3d)
+ image = nib_to_sitk(array, affine, force_3d=force_3d)
sitk.WriteImage(image, str(path), use_compression)
diff --git a/src/torchio/data/loader.py b/src/torchio/data/loader.py
new file mode 100644
index 00000000..ebc55dea
--- /dev/null
+++ b/src/torchio/data/loader.py
@@ -0,0 +1,64 @@
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import TypeVar
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader
+
+from .subject import Subject
+
+
+T = TypeVar('T')
+
+
+class SubjectsLoader(DataLoader):
+ def __init__(
+ self,
+ dataset: Dataset,
+ collate_fn: Optional[Callable[[List[T]], Any]] = None,
+ **kwargs,
+ ):
+ if collate_fn is None:
+ collate_fn = self._collate # type: ignore[assignment]
+ super().__init__(
+ dataset=dataset,
+ collate_fn=collate_fn,
+ **kwargs,
+ )
+
+ @staticmethod
+ def _collate(subjects: List[Subject]) -> Dict[str, Any]:
+ first_subject = subjects[0]
+ batch_dict = {}
+ for key in first_subject.keys():
+ collated_value = _stack([subject[key] for subject in subjects])
+ batch_dict[key] = collated_value
+ return batch_dict
+
+
+def _stack(x):
+ """Determine the type of the input and stack it accordingly.
+
+ Args:
+ x: List of elements to stack.
+ Returns:
+ Stacked elements, as either a torch.Tensor, np.ndarray, dict or list.
+ """
+ first_element = x[0]
+ if isinstance(first_element, torch.Tensor):
+ return torch.stack(x, dim=0)
+ elif isinstance(first_element, np.ndarray):
+ return np.stack(x, axis=0)
+ elif isinstance(first_element, dict):
+ # Assume that all elements have the same keys
+ collated_dict = {}
+ for key in first_element.keys():
+ collated_dict[key] = _stack([element[key] for element in x])
+ return collated_dict
+ else:
+ return x
diff --git a/src/torchio/data/queue.py b/src/torchio/data/queue.py
index 7f6b5872..6a77c8b0 100644
--- a/src/torchio/data/queue.py
+++ b/src/torchio/data/queue.py
@@ -115,7 +115,6 @@ class Queue(Dataset):
>>> import torch
>>> import torchio as tio
- >>> from torch.utils.data import DataLoader
>>> patch_size = 96
>>> queue_length = 300
>>> samples_per_volume = 10
@@ -129,7 +128,7 @@ class Queue(Dataset):
... sampler,
... num_workers=4,
... )
- >>> patches_loader = DataLoader(
+ >>> patches_loader = tio.SubjectsLoader(
... patches_queue,
... batch_size=16,
... num_workers=0, # this must be 0
@@ -168,7 +167,7 @@ class Queue(Dataset):
... num_workers=4,
... subject_sampler=subject_sampler,
... )
- >>> patches_loader = DataLoader(
+ >>> patches_loader = tio.SubjectsLoader(
... patches_queue,
... batch_size=16,
... num_workers=0, # this must be 0
diff --git a/src/torchio/data/subject.py b/src/torchio/data/subject.py
index 508dd222..767e1c60 100644
--- a/src/torchio/data/subject.py
+++ b/src/torchio/data/subject.py
@@ -442,7 +442,7 @@ def _subject_copy_helper(
value = copy.deepcopy(value)
result_dict[key] = value
- new = new_subj_cls(**result_dict)
+ new = new_subj_cls(**result_dict) # type: ignore[call-arg]
new.applied_transforms = old_obj.applied_transforms[:]
return new
diff --git a/src/torchio/transforms/preprocessing/intensity/rescale.py b/src/torchio/transforms/preprocessing/intensity/rescale.py
index 012cf22e..cdbdaee3 100644
--- a/src/torchio/transforms/preprocessing/intensity/rescale.py
+++ b/src/torchio/transforms/preprocessing/intensity/rescale.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
+from ....data.image import Image
from ....data.subject import Subject
from ....typing import TypeDoubleFloat
from .normalization_transform import NormalizationTransform
@@ -84,7 +85,7 @@ def apply_normalization(
image_name: str,
mask: torch.Tensor,
) -> None:
- image = subject[image_name]
+ image: Image = subject[image_name]
image.set_data(self.rescale(image.data, mask, image_name))
def rescale(
@@ -95,8 +96,8 @@ def rescale(
) -> torch.Tensor:
# The tensor is cloned as in-place operations will be used
array = tensor.clone().float().numpy()
- mask = mask.numpy()
- if not mask.any():
+ mask_array = mask.numpy()
+ if not mask_array.any():
message = (
f'Rescaling image "{image_name}" not possible'
' because the mask to compute the statistics is empty'
@@ -104,7 +105,7 @@ def rescale(
warnings.warn(message, RuntimeWarning, stacklevel=2)
return tensor
- values = array[mask]
+ values = array[mask_array]
cutoff = np.percentile(values, self.percentiles)
np.clip(array, *cutoff, out=array) # type: ignore[call-overload]
diff --git a/src/torchio/utils.py b/src/torchio/utils.py
index 4e65aef2..6f7eecc4 100644
--- a/src/torchio/utils.py
+++ b/src/torchio/utils.py
@@ -20,6 +20,7 @@
import numpy as np
import SimpleITK as sitk
import torch
+from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
from tqdm.auto import trange
@@ -239,7 +240,7 @@ def get_subclasses(target_class: type) -> List[type]:
return subclasses
-def get_first_item(data_loader: torch.utils.data.DataLoader):
+def get_first_item(data_loader: DataLoader):
return next(iter(data_loader))
@@ -247,7 +248,7 @@ def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]:
"""Get number of images and images names in a batch.
Args:
- batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
+ batch: Dictionary generated by a :class:`torchio.SubjectsLoader`
extracting data from a :class:`torchio.SubjectsDataset`.
Raises:
@@ -268,7 +269,7 @@ def get_subjects_from_batch(batch: Dict) -> List:
"""Get list of subjects from collated batch.
Args:
- batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
+ batch: Dictionary generated by a :class:`torchio.SubjectsLoader`
extracting data from a :class:`torchio.SubjectsDataset`.
"""
from .data import ScalarImage, LabelMap, Subject
diff --git a/tests/data/inference/test_aggregator.py b/tests/data/inference/test_aggregator.py
index 35737477..b1523ff8 100644
--- a/tests/data/inference/test_aggregator.py
+++ b/tests/data/inference/test_aggregator.py
@@ -18,7 +18,7 @@ def aggregate(self, mode, fixture):
patch_overlap = 0, 2, 2
sampler = tio.data.GridSampler(subject, patch_size, patch_overlap)
aggregator = tio.data.GridAggregator(sampler, overlap_mode=mode)
- loader = torch.utils.data.DataLoader(sampler, batch_size=3)
+ loader = tio.SubjectsLoader(sampler, batch_size=3)
values_dict = {
(0, 0): 0,
(0, 1): 2,
@@ -70,7 +70,7 @@ def run_sampler_aggregator(self, overlap_mode='crop'):
patch_size,
patch_overlap,
)
- patch_loader = torch.utils.data.DataLoader(grid_sampler)
+ patch_loader = tio.SubjectsLoader(grid_sampler)
aggregator = tio.inference.GridAggregator(
grid_sampler,
overlap_mode=overlap_mode,
@@ -102,7 +102,7 @@ def run_patch_crop_issue(self, *, padding_mode):
patch_size,
patch_overlap,
)
- patch_loader = torch.utils.data.DataLoader(grid_sampler)
+ patch_loader = tio.SubjectsLoader(grid_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler)
for patches_batch in patch_loader:
input_tensor = patches_batch['image'][tio.DATA]
@@ -131,7 +131,7 @@ def test_bad_aggregator_shape(self):
padding_mode='edge',
)
aggregator = tio.data.GridAggregator(sampler)
- loader = torch.utils.data.DataLoader(sampler, batch_size=3)
+ loader = tio.SubjectsLoader(sampler, batch_size=3)
for batch in loader:
input_batch = batch[image_name][tio.DATA]
crop = tio.CropOrPad(12)
diff --git a/tests/data/inference/test_inference.py b/tests/data/inference/test_inference.py
index 71abfdc6..8f20ed5d 100644
--- a/tests/data/inference/test_inference.py
+++ b/tests/data/inference/test_inference.py
@@ -1,4 +1,4 @@
-from torch.utils.data import DataLoader
+import torchio as tio
from torchio import DATA
from torchio import LOCATION
from torchio.data.inference import GridAggregator
@@ -29,7 +29,7 @@ def try_inference(self, padding_mode):
padding_mode=padding_mode,
)
aggregator = GridAggregator(grid_sampler)
- patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
+ patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=batch_size)
for patches_batch in patch_loader:
input_tensor = patches_batch['t1'][DATA]
locations = patches_batch[LOCATION]
diff --git a/tests/data/test_queue.py b/tests/data/test_queue.py
index 2bfd18f3..1e6ae49d 100644
--- a/tests/data/test_queue.py
+++ b/tests/data/test_queue.py
@@ -4,7 +4,6 @@
import torch
import torchio as tio
from parameterized import parameterized
-from torch.utils.data import DataLoader
from torchio.data import UniformSampler
from torchio.utils import create_dummy_dataset
@@ -37,7 +36,7 @@ def run_queue(self, num_workers=0, **kwargs):
**kwargs,
)
_ = str(queue_dataset)
- batch_loader = DataLoader(queue_dataset, batch_size=4)
+ batch_loader = tio.SubjectsLoader(queue_dataset, batch_size=4)
for batch in batch_loader:
_ = batch['one_modality'][tio.DATA]
_ = batch['segmentation'][tio.DATA]
@@ -69,7 +68,7 @@ def test_different_samples_per_volume(self, max_length):
sampler=sampler,
shuffle_patches=False,
)
- batch_loader = DataLoader(queue_dataset, batch_size=6)
+ batch_loader = tio.SubjectsLoader(queue_dataset, batch_size=6)
tensors = [batch['im'][tio.DATA] for batch in batch_loader]
all_numbers = torch.stack(tensors).flatten().tolist()
assert all_numbers.count(10) == 10
diff --git a/tests/data/test_subject.py b/tests/data/test_subject.py
index e94fbb98..8a9cb747 100644
--- a/tests/data/test_subject.py
+++ b/tests/data/test_subject.py
@@ -6,7 +6,6 @@
import pytest
import torch
import torchio as tio
-from torch.utils.data import DataLoader
from ..utils import TorchioTestCase
@@ -183,6 +182,6 @@ def test_load_unload(self):
def test_subjects_batch(self):
subjects = tio.SubjectsDataset(10 * [self.sample_subject])
- loader = DataLoader(subjects, batch_size=4)
+ loader = tio.SubjectsLoader(subjects, batch_size=4)
batch = next(iter(loader))
assert batch.__class__ is dict
diff --git a/tests/data/test_subjects_dataset.py b/tests/data/test_subjects_dataset.py
index 2c5fd231..9187238b 100644
--- a/tests/data/test_subjects_dataset.py
+++ b/tests/data/test_subjects_dataset.py
@@ -1,7 +1,6 @@
import pytest
import torch
import torchio as tio
-from torch.utils.data import DataLoader
from ..utils import TorchioTestCase
@@ -57,7 +56,7 @@ def iterate_dataset(subjects_list):
def test_from_batch(self):
dataset = tio.SubjectsDataset([self.sample_subject])
- loader = DataLoader(dataset)
+ loader = tio.SubjectsLoader(dataset)
batch = tio.utils.get_first_item(loader)
new_dataset = tio.SubjectsDataset.from_batch(batch)
self.assert_tensor_equal(
diff --git a/tests/datasets/test_medmnist.py b/tests/datasets/test_medmnist.py
index 2d527877..680eec49 100644
--- a/tests/datasets/test_medmnist.py
+++ b/tests/datasets/test_medmnist.py
@@ -1,7 +1,7 @@
import os
import pytest
-import torch
+import torchio as tio
from torchio.datasets.medmnist import AdrenalMNIST3D
from torchio.datasets.medmnist import FractureMNIST3D
from torchio.datasets.medmnist import NoduleMNIST3D
@@ -26,7 +26,7 @@
@pytest.mark.parametrize('split', ('train', 'val', 'test'))
def test_load_all(class_, split):
dataset = class_(split)
- loader = torch.utils.data.DataLoader(
+ loader = tio.SubjectsLoader(
dataset,
batch_size=256,
)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7bda672c..6070896a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -43,7 +43,7 @@ def test_apply_transform_to_file(self):
def test_subjects_from_batch(self):
dataset = tio.SubjectsDataset(4 * [self.sample_subject])
- loader = torch.utils.data.DataLoader(dataset, batch_size=4)
+ loader = tio.SubjectsLoader(dataset, batch_size=4)
batch = tio.utils.get_first_item(loader)
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)
@@ -55,7 +55,7 @@ def test_subjects_from_batch_with_string_metadata(self):
)
dataset = tio.SubjectsDataset(4 * [subject_c_with_string_metadata])
- loader = torch.utils.data.DataLoader(dataset, batch_size=4)
+ loader = tio.SubjectsLoader(dataset, batch_size=4)
batch = tio.utils.get_first_item(loader)
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)
@@ -68,7 +68,7 @@ def test_subjects_from_batch_with_int_metadata(self):
label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
)
dataset = tio.SubjectsDataset(4 * [subject_c_with_int_metadata])
- loader = torch.utils.data.DataLoader(dataset, batch_size=4)
+ loader = tio.SubjectsLoader(dataset, batch_size=4)
batch = tio.utils.get_first_item(loader)
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)
diff --git a/tests/transforms/test_collate.py b/tests/transforms/test_collate.py
index eed377d6..25e95ebe 100644
--- a/tests/transforms/test_collate.py
+++ b/tests/transforms/test_collate.py
@@ -1,5 +1,4 @@
import torchio as tio
-from torch.utils.data import DataLoader
from ..utils import TorchioTestCase
@@ -28,11 +27,11 @@ def __getitem__(self, index):
return Dataset(data)
def test_collate(self):
- loader = DataLoader(self.get_heterogeneous_dataset(), batch_size=2)
+ loader = tio.SubjectsLoader(self.get_heterogeneous_dataset(), batch_size=2)
tio.utils.get_first_item(loader)
def test_history_collate(self):
- loader = DataLoader(
+ loader = tio.SubjectsLoader(
self.get_heterogeneous_dataset(),
batch_size=4,
collate_fn=tio.utils.history_collate,
diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py
index 9fdf5b83..4c424c0e 100644
--- a/tests/transforms/test_transforms.py
+++ b/tests/transforms/test_transforms.py
@@ -291,7 +291,7 @@ def test_batch_history(self):
]
)
dataset = tio.SubjectsDataset([subject], transform=transform)
- loader = torch.utils.data.DataLoader(
+ loader = tio.SubjectsLoader(
dataset,
collate_fn=tio.utils.history_collate,
)
diff --git a/tutorials/example_heteromodal.py b/tutorials/example_heteromodal.py
index a50de35f..4a3fa06a 100644
--- a/tutorials/example_heteromodal.py
+++ b/tutorials/example_heteromodal.py
@@ -8,7 +8,7 @@
import logging
import torch.nn as nn
-from torch.utils.data import DataLoader
+import torchio as tio
from torchio import LabelMap
from torchio import Queue
from torchio import ScalarImage
@@ -54,7 +54,7 @@ def main():
# This collate_fn is needed in the case of missing modalities
# In this case, the batch will be composed by a *list* of samples instead
# of the typical Python dictionary that is collated by default in Pytorch
- batch_loader = DataLoader(
+ batch_loader = tio.SubjectsLoader(
queue_dataset,
batch_size=batch_size,
collate_fn=lambda x: x,