Skip to content

Commit

Permalink
Add warning if SubjectsLoader is not used in PyTorch >= 2.3 (#1215)
Browse files Browse the repository at this point in the history
* Add warning if SubjectsLoader is not used

* Improve warning and move to method

Co-authored-by: valabregue <romain.valabregue@upmc.fr>
  • Loading branch information
fepegar and romainVala authored Sep 26, 2024
1 parent 0314926 commit 5df1638
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..typing import TypeTripletInt
from ..utils import get_stem
from ..utils import guess_external_viewer
from ..utils import in_torch_loader
from ..utils import is_iterable
from ..utils import to_tuple
from .io import check_uint_to_int
Expand Down Expand Up @@ -176,6 +177,7 @@ def __init__(
warnings.warn(message, DeprecationWarning, stacklevel=2)

super().__init__(**kwargs)
self._check_data_loader()
self.path = self._parse_path(path)

self[PATH] = '' if self.path is None else str(self.path)
Expand Down Expand Up @@ -234,6 +236,20 @@ def __copy__(self):
)
return new_image

@staticmethod
def _check_data_loader() -> None:
if torch.__version__ >= '2.3' and in_torch_loader():
message = (
'Using TorchIO images without a torchio.SubjectsLoader in PyTorch >='
' 2.3 might have unexpected consequences, e.g., the collated batches'
' will be instances of torchio.Subject with 5D images. Replace'
' your PyTorch DataLoader with a torchio.SubjectsLoader so that'
' the collated batch becomes a dictionary, as expected. See'
' https://github.com/fepegar/torchio/issues/1179 for more'
' context about this issue.'
)
warnings.warn(message, stacklevel=1)

@property
def data(self) -> torch.Tensor:
"""Tensor data (same as :class:`Image.tensor`)."""
Expand Down
23 changes: 23 additions & 0 deletions src/torchio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import gzip
import inspect
import os
import shutil
import sys
Expand All @@ -19,6 +20,7 @@
import numpy as np
import SimpleITK as sitk
import torch
import torch.utils.data.dataloader
from nibabel.nifti1 import Nifti1Image
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
Expand Down Expand Up @@ -415,3 +417,24 @@ def is_iterable(object: Any) -> bool:
return True
except TypeError:
return False


def in_class(classes) -> bool:
classes = to_tuple(classes)
stack = inspect.stack()
for frame_info in stack:
instance = frame_info.frame.f_locals.get('self')
if instance is None:
continue
if instance.__class__ in classes:
return True
else:
return False


def in_torch_loader() -> bool:
classes = (
torch.utils.data.dataloader._SingleProcessDataLoaderIter,
torch.utils.data.dataloader._MultiProcessingDataLoaderIter,
)
return in_class(classes)

0 comments on commit 5df1638

Please sign in to comment.