diff --git a/.all-contributorsrc b/.all-contributorsrc index 92832ff8..2a6692b9 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -815,6 +815,15 @@ "contributions": [ "bug" ] + }, + { + "login": "rickymwalsh", + "name": "Ricky Walsh", + "avatar_url": "https://avatars.githubusercontent.com/u/70853488?v=4", + "profile": "https://www.linkedin.com/in/ricky-walsh/", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, diff --git a/.bumpversion.cfg b/.bumpversion.cfg index e98d531a..108a32a5 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.19.9 +current_version = 0.20.0 commit = True tag = True diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 145e9d1e..45eaa58f 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -40,8 +40,6 @@ jobs: - name: Publish package to TestPyPI uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ - password: ${{ secrets.TEST_PYPI_API_TOKEN }} repository-url: https://test.pypi.org/legacy/ verbose: true skip-existing: true @@ -51,5 +49,4 @@ jobs: if: startsWith(github.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@release/v1 with: - password: ${{ secrets.PYPI_API_TOKEN }} verbose: true diff --git a/README.md b/README.md index bbb256af..0763c9fc 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,7 @@ Thanks goes to all these people ([emoji key](https://allcontributors.org/docs/en marius-sm
marius-sm

🤔 haarisr
haarisr

💻 Chris Winder
Chris Winder

🐛 + Ricky Walsh
Ricky Walsh

💻 diff --git a/docs/examples/plot_3d_to_2d.py b/docs/examples/plot_3d_to_2d.py index 7046b8b7..fb26c78e 100644 --- a/docs/examples/plot_3d_to_2d.py +++ b/docs/examples/plot_3d_to_2d.py @@ -32,7 +32,7 @@ def plot_batch(sampler): queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler) - loader = torch.utils.data.DataLoader(queue, batch_size=16) + loader = tio.SubjectsLoader(queue, batch_size=16) batch = tio.utils.get_first_item(loader) fig, axes = plt.subplots(4, 4, figsize=(12, 10)) diff --git a/docs/examples/plot_history.py b/docs/examples/plot_history.py index 0b39f8a7..3af4424a 100644 --- a/docs/examples/plot_history.py +++ b/docs/examples/plot_history.py @@ -43,7 +43,7 @@ ) # noqa: T201, B950 print(transformed.get_inverse_transform(ignore_intensity=False)) # noqa: T201 -loader = torch.utils.data.DataLoader( +loader = tio.SubjectsLoader( dataset, batch_size=batch_size, collate_fn=tio.utils.history_collate, diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index fa0a39ed..f5bf52fe 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -206,7 +206,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.OrganMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' @@ -224,7 +224,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.NoduleMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' @@ -242,7 +242,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.AdrenalMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' @@ -260,7 +260,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.FractureMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' @@ -278,7 +278,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.VesselMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' @@ -296,7 +296,7 @@ MedMNIST from einops import rearrange rows, cols = 16, 28 dataset = tio.datasets.SynapseMNIST3D('train') - loader = torch.utils.data.DataLoader(dataset, batch_size=rows * cols) + loader = tio.SubjectsLoader(dataset, batch_size=rows * cols) batch = tio.utils.get_first_item(loader) tensor = batch['image'][tio.DATA] pattern = '(b1 b2) c x y z -> c x (b1 y) (b2 z)' diff --git a/docs/source/patches/patch_inference.rst b/docs/source/patches/patch_inference.rst index 836ad512..98e1d77e 100644 --- a/docs/source/patches/patch_inference.rst +++ b/docs/source/patches/patch_inference.rst @@ -17,7 +17,7 @@ inference across a 3D image using patches:: ... patch_size, ... patch_overlap, ... ) - >>> patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4) + >>> patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4) >>> aggregator = tio.inference.GridAggregator(grid_sampler) >>> model = nn.Identity().eval() >>> with torch.no_grad(): diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e09bc816..e84fd6e3 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -37,13 +37,12 @@ Hello, World! This example shows the basic usage of TorchIO, where an instance of :class:`~torchio.SubjectsDataset` is passed to -a PyTorch :class:`~torch.utils.data.DataLoader` to generate training batches +a PyTorch :class:`~torch.SubjectsLoader` to generate training batches of 3D images that are loaded, preprocessed and augmented on the fly, in parallel:: import torch import torchio as tio - from torch.utils.data import DataLoader # Each instance of tio.Subject is passed arbitrary keyword arguments. # Typically, these arguments will be instances of tio.Image @@ -91,8 +90,14 @@ in parallel:: # SubjectsDataset is a subclass of torch.data.utils.Dataset subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform) - # Images are processed in parallel thanks to a PyTorch DataLoader - training_loader = DataLoader(subjects_dataset, batch_size=4, num_workers=4) + # Images are processed in parallel thanks to a SubjectsLoader + # (which inherits from torch.utils.data.DataLoader) + training_loader = tio.SubjectsLoader( + subjects_dataset, + batch_size=4, + num_workers=4, + shuffle=True, + ) # Training epoch for subjects_batch in training_loader: diff --git a/setup.cfg b/setup.cfg index 4e54a9de..847c26a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = torchio -version = 0.19.9 +version = 0.20.0 description = Tools for medical image processing with PyTorch long_description = file: README.md long_description_content_type = text/markdown diff --git a/src/torchio/__init__.py b/src/torchio/__init__.py index 1a1f0403..85de897c 100644 --- a/src/torchio/__init__.py +++ b/src/torchio/__init__.py @@ -2,7 +2,7 @@ __author__ = """Fernando Perez-Garcia""" __email__ = 'fepegar@gmail.com' -__version__ = '0.19.9' +__version__ = '0.20.0' from . import utils @@ -13,6 +13,7 @@ sampler, inference, SubjectsDataset, + SubjectsLoader, Image, ScalarImage, LabelMap, @@ -34,6 +35,7 @@ 'sampler', 'inference', 'SubjectsDataset', + 'SubjectsLoader', 'Image', 'ScalarImage', 'LabelMap', diff --git a/src/torchio/data/__init__.py b/src/torchio/data/__init__.py index f97c42e8..7ebc6724 100644 --- a/src/torchio/data/__init__.py +++ b/src/torchio/data/__init__.py @@ -3,6 +3,7 @@ from .image import LabelMap from .image import ScalarImage from .inference import GridAggregator +from .loader import SubjectsLoader from .queue import Queue from .sampler import GridSampler from .sampler import LabelSampler @@ -16,6 +17,7 @@ 'Queue', 'Subject', 'SubjectsDataset', + 'SubjectsLoader', 'Image', 'ScalarImage', 'LabelMap', diff --git a/src/torchio/data/dataset.py b/src/torchio/data/dataset.py index 4a0e4c31..d1bdfc23 100644 --- a/src/torchio/data/dataset.py +++ b/src/torchio/data/dataset.py @@ -17,8 +17,8 @@ class SubjectsDataset(Dataset): """Base TorchIO dataset. Reader of 3D medical images that directly inherits from the PyTorch - :class:`~torch.utils.data.Dataset`. It can be used with a PyTorch - :class:`~torch.utils.data.DataLoader` for efficient loading and + :class:`~torch.utils.data.Dataset`. It can be used with a + :class:`~tio.SubjectsLoader` for efficient loading and augmentation. It receives a list of instances of :class:`~torchio.Subject` and an optional transform applied to the volumes after loading. diff --git a/src/torchio/data/inference/aggregator.py b/src/torchio/data/inference/aggregator.py index c1c0ae9b..52985cad 100644 --- a/src/torchio/data/inference/aggregator.py +++ b/src/torchio/data/inference/aggregator.py @@ -139,8 +139,8 @@ def add_batch( extracted using ``batch[torchio.LOCATION]``. """ batch = batch_tensor.cpu() - locations = locations.cpu().numpy() - patch_sizes = locations[:, 3:] - locations[:, :3] + locations_array = locations.cpu().numpy() + patch_sizes = locations_array[:, 3:] - locations_array[:, :3] # There should be only one patch size assert len(np.unique(patch_sizes, axis=0)) == 1 input_spatial_shape = tuple(batch.shape[-3:]) @@ -155,7 +155,7 @@ def add_batch( self._initialize_output_tensor(batch) assert isinstance(self._output_tensor, torch.Tensor) if self.overlap_mode == 'crop': - for patch, location in zip(batch, locations): + for patch, location in zip(batch, locations_array): cropped_patch, new_location = self._crop_patch( patch, location, diff --git a/src/torchio/data/io.py b/src/torchio/data/io.py index dd9704ab..d463f492 100644 --- a/src/torchio/data/io.py +++ b/src/torchio/data/io.py @@ -181,18 +181,19 @@ def _write_sitk( ) -> None: assert tensor.ndim == 4 path = Path(path) + array = tensor.numpy() if path.suffix in ('.png', '.jpg', '.jpeg', '.bmp'): warnings.warn( f'Casting to uint 8 before saving to {path}', RuntimeWarning, stacklevel=2, ) - tensor = tensor.numpy().astype(np.uint8) + array = array.astype(np.uint8) if squeeze is None: force_3d = path.suffix not in IMAGE_2D_FORMATS else: force_3d = not squeeze - image = nib_to_sitk(tensor, affine, force_3d=force_3d) + image = nib_to_sitk(array, affine, force_3d=force_3d) sitk.WriteImage(image, str(path), use_compression) diff --git a/src/torchio/data/loader.py b/src/torchio/data/loader.py new file mode 100644 index 00000000..ebc55dea --- /dev/null +++ b/src/torchio/data/loader.py @@ -0,0 +1,64 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import TypeVar + +import numpy as np +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader + +from .subject import Subject + + +T = TypeVar('T') + + +class SubjectsLoader(DataLoader): + def __init__( + self, + dataset: Dataset, + collate_fn: Optional[Callable[[List[T]], Any]] = None, + **kwargs, + ): + if collate_fn is None: + collate_fn = self._collate # type: ignore[assignment] + super().__init__( + dataset=dataset, + collate_fn=collate_fn, + **kwargs, + ) + + @staticmethod + def _collate(subjects: List[Subject]) -> Dict[str, Any]: + first_subject = subjects[0] + batch_dict = {} + for key in first_subject.keys(): + collated_value = _stack([subject[key] for subject in subjects]) + batch_dict[key] = collated_value + return batch_dict + + +def _stack(x): + """Determine the type of the input and stack it accordingly. + + Args: + x: List of elements to stack. + Returns: + Stacked elements, as either a torch.Tensor, np.ndarray, dict or list. + """ + first_element = x[0] + if isinstance(first_element, torch.Tensor): + return torch.stack(x, dim=0) + elif isinstance(first_element, np.ndarray): + return np.stack(x, axis=0) + elif isinstance(first_element, dict): + # Assume that all elements have the same keys + collated_dict = {} + for key in first_element.keys(): + collated_dict[key] = _stack([element[key] for element in x]) + return collated_dict + else: + return x diff --git a/src/torchio/data/queue.py b/src/torchio/data/queue.py index 7f6b5872..6a77c8b0 100644 --- a/src/torchio/data/queue.py +++ b/src/torchio/data/queue.py @@ -115,7 +115,6 @@ class Queue(Dataset): >>> import torch >>> import torchio as tio - >>> from torch.utils.data import DataLoader >>> patch_size = 96 >>> queue_length = 300 >>> samples_per_volume = 10 @@ -129,7 +128,7 @@ class Queue(Dataset): ... sampler, ... num_workers=4, ... ) - >>> patches_loader = DataLoader( + >>> patches_loader = tio.SubjectsLoader( ... patches_queue, ... batch_size=16, ... num_workers=0, # this must be 0 @@ -168,7 +167,7 @@ class Queue(Dataset): ... num_workers=4, ... subject_sampler=subject_sampler, ... ) - >>> patches_loader = DataLoader( + >>> patches_loader = tio.SubjectsLoader( ... patches_queue, ... batch_size=16, ... num_workers=0, # this must be 0 diff --git a/src/torchio/data/subject.py b/src/torchio/data/subject.py index 508dd222..767e1c60 100644 --- a/src/torchio/data/subject.py +++ b/src/torchio/data/subject.py @@ -442,7 +442,7 @@ def _subject_copy_helper( value = copy.deepcopy(value) result_dict[key] = value - new = new_subj_cls(**result_dict) + new = new_subj_cls(**result_dict) # type: ignore[call-arg] new.applied_transforms = old_obj.applied_transforms[:] return new diff --git a/src/torchio/transforms/preprocessing/intensity/rescale.py b/src/torchio/transforms/preprocessing/intensity/rescale.py index 012cf22e..cdbdaee3 100644 --- a/src/torchio/transforms/preprocessing/intensity/rescale.py +++ b/src/torchio/transforms/preprocessing/intensity/rescale.py @@ -4,6 +4,7 @@ import numpy as np import torch +from ....data.image import Image from ....data.subject import Subject from ....typing import TypeDoubleFloat from .normalization_transform import NormalizationTransform @@ -84,7 +85,7 @@ def apply_normalization( image_name: str, mask: torch.Tensor, ) -> None: - image = subject[image_name] + image: Image = subject[image_name] image.set_data(self.rescale(image.data, mask, image_name)) def rescale( @@ -95,8 +96,8 @@ def rescale( ) -> torch.Tensor: # The tensor is cloned as in-place operations will be used array = tensor.clone().float().numpy() - mask = mask.numpy() - if not mask.any(): + mask_array = mask.numpy() + if not mask_array.any(): message = ( f'Rescaling image "{image_name}" not possible' ' because the mask to compute the statistics is empty' @@ -104,7 +105,7 @@ def rescale( warnings.warn(message, RuntimeWarning, stacklevel=2) return tensor - values = array[mask] + values = array[mask_array] cutoff = np.percentile(values, self.percentiles) np.clip(array, *cutoff, out=array) # type: ignore[call-overload] diff --git a/src/torchio/utils.py b/src/torchio/utils.py index 4e65aef2..6f7eecc4 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -20,6 +20,7 @@ import numpy as np import SimpleITK as sitk import torch +from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate from tqdm.auto import trange @@ -239,7 +240,7 @@ def get_subclasses(target_class: type) -> List[type]: return subclasses -def get_first_item(data_loader: torch.utils.data.DataLoader): +def get_first_item(data_loader: DataLoader): return next(iter(data_loader)) @@ -247,7 +248,7 @@ def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]: """Get number of images and images names in a batch. Args: - batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` + batch: Dictionary generated by a :class:`tio.SubjectsLoader` extracting data from a :class:`torchio.SubjectsDataset`. Raises: @@ -268,7 +269,7 @@ def get_subjects_from_batch(batch: Dict) -> List: """Get list of subjects from collated batch. Args: - batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` + batch: Dictionary generated by a :class:`tio.SubjectsLoader` extracting data from a :class:`torchio.SubjectsDataset`. """ from .data import ScalarImage, LabelMap, Subject diff --git a/tests/data/inference/test_aggregator.py b/tests/data/inference/test_aggregator.py index 35737477..b1523ff8 100644 --- a/tests/data/inference/test_aggregator.py +++ b/tests/data/inference/test_aggregator.py @@ -18,7 +18,7 @@ def aggregate(self, mode, fixture): patch_overlap = 0, 2, 2 sampler = tio.data.GridSampler(subject, patch_size, patch_overlap) aggregator = tio.data.GridAggregator(sampler, overlap_mode=mode) - loader = torch.utils.data.DataLoader(sampler, batch_size=3) + loader = tio.SubjectsLoader(sampler, batch_size=3) values_dict = { (0, 0): 0, (0, 1): 2, @@ -70,7 +70,7 @@ def run_sampler_aggregator(self, overlap_mode='crop'): patch_size, patch_overlap, ) - patch_loader = torch.utils.data.DataLoader(grid_sampler) + patch_loader = tio.SubjectsLoader(grid_sampler) aggregator = tio.inference.GridAggregator( grid_sampler, overlap_mode=overlap_mode, @@ -102,7 +102,7 @@ def run_patch_crop_issue(self, *, padding_mode): patch_size, patch_overlap, ) - patch_loader = torch.utils.data.DataLoader(grid_sampler) + patch_loader = tio.SubjectsLoader(grid_sampler) aggregator = tio.inference.GridAggregator(grid_sampler) for patches_batch in patch_loader: input_tensor = patches_batch['image'][tio.DATA] @@ -131,7 +131,7 @@ def test_bad_aggregator_shape(self): padding_mode='edge', ) aggregator = tio.data.GridAggregator(sampler) - loader = torch.utils.data.DataLoader(sampler, batch_size=3) + loader = tio.SubjectsLoader(sampler, batch_size=3) for batch in loader: input_batch = batch[image_name][tio.DATA] crop = tio.CropOrPad(12) diff --git a/tests/data/inference/test_inference.py b/tests/data/inference/test_inference.py index 71abfdc6..8f20ed5d 100644 --- a/tests/data/inference/test_inference.py +++ b/tests/data/inference/test_inference.py @@ -1,4 +1,4 @@ -from torch.utils.data import DataLoader +import torchio as tio from torchio import DATA from torchio import LOCATION from torchio.data.inference import GridAggregator @@ -29,7 +29,7 @@ def try_inference(self, padding_mode): padding_mode=padding_mode, ) aggregator = GridAggregator(grid_sampler) - patch_loader = DataLoader(grid_sampler, batch_size=batch_size) + patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=batch_size) for patches_batch in patch_loader: input_tensor = patches_batch['t1'][DATA] locations = patches_batch[LOCATION] diff --git a/tests/data/test_queue.py b/tests/data/test_queue.py index 2bfd18f3..1e6ae49d 100644 --- a/tests/data/test_queue.py +++ b/tests/data/test_queue.py @@ -4,7 +4,6 @@ import torch import torchio as tio from parameterized import parameterized -from torch.utils.data import DataLoader from torchio.data import UniformSampler from torchio.utils import create_dummy_dataset @@ -37,7 +36,7 @@ def run_queue(self, num_workers=0, **kwargs): **kwargs, ) _ = str(queue_dataset) - batch_loader = DataLoader(queue_dataset, batch_size=4) + batch_loader = tio.SubjectsLoader(queue_dataset, batch_size=4) for batch in batch_loader: _ = batch['one_modality'][tio.DATA] _ = batch['segmentation'][tio.DATA] @@ -69,7 +68,7 @@ def test_different_samples_per_volume(self, max_length): sampler=sampler, shuffle_patches=False, ) - batch_loader = DataLoader(queue_dataset, batch_size=6) + batch_loader = tio.SubjectsLoader(queue_dataset, batch_size=6) tensors = [batch['im'][tio.DATA] for batch in batch_loader] all_numbers = torch.stack(tensors).flatten().tolist() assert all_numbers.count(10) == 10 diff --git a/tests/data/test_subject.py b/tests/data/test_subject.py index e94fbb98..8a9cb747 100644 --- a/tests/data/test_subject.py +++ b/tests/data/test_subject.py @@ -6,7 +6,6 @@ import pytest import torch import torchio as tio -from torch.utils.data import DataLoader from ..utils import TorchioTestCase @@ -183,6 +182,6 @@ def test_load_unload(self): def test_subjects_batch(self): subjects = tio.SubjectsDataset(10 * [self.sample_subject]) - loader = DataLoader(subjects, batch_size=4) + loader = tio.SubjectsLoader(subjects, batch_size=4) batch = next(iter(loader)) assert batch.__class__ is dict diff --git a/tests/data/test_subjects_dataset.py b/tests/data/test_subjects_dataset.py index 2c5fd231..9187238b 100644 --- a/tests/data/test_subjects_dataset.py +++ b/tests/data/test_subjects_dataset.py @@ -1,7 +1,6 @@ import pytest import torch import torchio as tio -from torch.utils.data import DataLoader from ..utils import TorchioTestCase @@ -57,7 +56,7 @@ def iterate_dataset(subjects_list): def test_from_batch(self): dataset = tio.SubjectsDataset([self.sample_subject]) - loader = DataLoader(dataset) + loader = tio.SubjectsLoader(dataset) batch = tio.utils.get_first_item(loader) new_dataset = tio.SubjectsDataset.from_batch(batch) self.assert_tensor_equal( diff --git a/tests/datasets/test_medmnist.py b/tests/datasets/test_medmnist.py index 2d527877..680eec49 100644 --- a/tests/datasets/test_medmnist.py +++ b/tests/datasets/test_medmnist.py @@ -1,7 +1,7 @@ import os import pytest -import torch +import torchio as tio from torchio.datasets.medmnist import AdrenalMNIST3D from torchio.datasets.medmnist import FractureMNIST3D from torchio.datasets.medmnist import NoduleMNIST3D @@ -26,7 +26,7 @@ @pytest.mark.parametrize('split', ('train', 'val', 'test')) def test_load_all(class_, split): dataset = class_(split) - loader = torch.utils.data.DataLoader( + loader = tio.SubjectsLoader( dataset, batch_size=256, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7bda672c..6070896a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -43,7 +43,7 @@ def test_apply_transform_to_file(self): def test_subjects_from_batch(self): dataset = tio.SubjectsDataset(4 * [self.sample_subject]) - loader = torch.utils.data.DataLoader(dataset, batch_size=4) + loader = tio.SubjectsLoader(dataset, batch_size=4) batch = tio.utils.get_first_item(loader) subjects = tio.utils.get_subjects_from_batch(batch) assert isinstance(subjects[0], tio.Subject) @@ -55,7 +55,7 @@ def test_subjects_from_batch_with_string_metadata(self): ) dataset = tio.SubjectsDataset(4 * [subject_c_with_string_metadata]) - loader = torch.utils.data.DataLoader(dataset, batch_size=4) + loader = tio.SubjectsLoader(dataset, batch_size=4) batch = tio.utils.get_first_item(loader) subjects = tio.utils.get_subjects_from_batch(batch) assert isinstance(subjects[0], tio.Subject) @@ -68,7 +68,7 @@ def test_subjects_from_batch_with_int_metadata(self): label=tio.LabelMap(self.get_image_path('label_c', binary=True)), ) dataset = tio.SubjectsDataset(4 * [subject_c_with_int_metadata]) - loader = torch.utils.data.DataLoader(dataset, batch_size=4) + loader = tio.SubjectsLoader(dataset, batch_size=4) batch = tio.utils.get_first_item(loader) subjects = tio.utils.get_subjects_from_batch(batch) assert isinstance(subjects[0], tio.Subject) diff --git a/tests/transforms/test_collate.py b/tests/transforms/test_collate.py index eed377d6..25e95ebe 100644 --- a/tests/transforms/test_collate.py +++ b/tests/transforms/test_collate.py @@ -1,5 +1,4 @@ import torchio as tio -from torch.utils.data import DataLoader from ..utils import TorchioTestCase @@ -28,11 +27,11 @@ def __getitem__(self, index): return Dataset(data) def test_collate(self): - loader = DataLoader(self.get_heterogeneous_dataset(), batch_size=2) + loader = tio.SubjectsLoader(self.get_heterogeneous_dataset(), batch_size=2) tio.utils.get_first_item(loader) def test_history_collate(self): - loader = DataLoader( + loader = tio.SubjectsLoader( self.get_heterogeneous_dataset(), batch_size=4, collate_fn=tio.utils.history_collate, diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 9fdf5b83..4c424c0e 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -291,7 +291,7 @@ def test_batch_history(self): ] ) dataset = tio.SubjectsDataset([subject], transform=transform) - loader = torch.utils.data.DataLoader( + loader = tio.SubjectsLoader( dataset, collate_fn=tio.utils.history_collate, ) diff --git a/tutorials/example_heteromodal.py b/tutorials/example_heteromodal.py index a50de35f..4a3fa06a 100644 --- a/tutorials/example_heteromodal.py +++ b/tutorials/example_heteromodal.py @@ -8,7 +8,7 @@ import logging import torch.nn as nn -from torch.utils.data import DataLoader +import torchio as tio from torchio import LabelMap from torchio import Queue from torchio import ScalarImage @@ -54,7 +54,7 @@ def main(): # This collate_fn is needed in the case of missing modalities # In this case, the batch will be composed by a *list* of samples instead # of the typical Python dictionary that is collated by default in Pytorch - batch_loader = DataLoader( + batch_loader = tio.SubjectsLoader( queue_dataset, batch_size=batch_size, collate_fn=lambda x: x,