From 85457504626c162f1bedb11a6293321c27a087b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Wed, 25 Sep 2024 22:38:51 +0100 Subject: [PATCH 1/2] Add warning if SubjectsLoader is not used --- src/torchio/data/image.py | 10 ++++++++++ src/torchio/utils.py | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py index 2f980e75..568bf5c9 100644 --- a/src/torchio/data/image.py +++ b/src/torchio/data/image.py @@ -36,6 +36,7 @@ from ..typing import TypeTripletInt from ..utils import get_stem from ..utils import guess_external_viewer +from ..utils import in_torch_loader from ..utils import is_iterable from ..utils import to_tuple from .io import check_uint_to_int @@ -176,6 +177,15 @@ def __init__( warnings.warn(message, DeprecationWarning, stacklevel=2) super().__init__(**kwargs) + if torch.__version__ >= '2.3' and in_torch_loader(): + message = ( + 'Using TorchIO images without a SubjectsLoader in PyTorch >=' + ' 2.3 might have unexpected consequences. Please replace your' + ' PyTorch DataLoader with a SubjectsLoader. See' + ' https://github.com/fepegar/torchio/issues/1179 for more' + ' context about this problem' + ) + warnings.warn(message, stacklevel=1) self.path = self._parse_path(path) self[PATH] = '' if self.path is None else str(self.path) diff --git a/src/torchio/utils.py b/src/torchio/utils.py index 436536cc..e206af2e 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -2,6 +2,7 @@ import ast import gzip +import inspect import os import shutil import sys @@ -19,6 +20,7 @@ import numpy as np import SimpleITK as sitk import torch +import torch.utils.data.dataloader from nibabel.nifti1 import Nifti1Image from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate @@ -415,3 +417,24 @@ def is_iterable(object: Any) -> bool: return True except TypeError: return False + + +def in_class(classes) -> bool: + classes = to_tuple(classes) + stack = inspect.stack() + for frame_info in stack: + instance = frame_info.frame.f_locals.get('self') + if instance is None: + continue + if instance.__class__ in classes: + return True + else: + return False + + +def in_torch_loader() -> bool: + classes = ( + torch.utils.data.dataloader._SingleProcessDataLoaderIter, + torch.utils.data.dataloader._MultiProcessingDataLoaderIter, + ) + return in_class(classes) From 8408171cd0ecdecafc3099bc8f53ba899f919ff4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Thu, 26 Sep 2024 23:30:37 +0100 Subject: [PATCH 2/2] Improve warning and move to method --- src/torchio/data/image.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py index 568bf5c9..a48b21d6 100644 --- a/src/torchio/data/image.py +++ b/src/torchio/data/image.py @@ -177,15 +177,7 @@ def __init__( warnings.warn(message, DeprecationWarning, stacklevel=2) super().__init__(**kwargs) - if torch.__version__ >= '2.3' and in_torch_loader(): - message = ( - 'Using TorchIO images without a SubjectsLoader in PyTorch >=' - ' 2.3 might have unexpected consequences. Please replace your' - ' PyTorch DataLoader with a SubjectsLoader. See' - ' https://github.com/fepegar/torchio/issues/1179 for more' - ' context about this problem' - ) - warnings.warn(message, stacklevel=1) + self._check_data_loader() self.path = self._parse_path(path) self[PATH] = '' if self.path is None else str(self.path) @@ -244,6 +236,20 @@ def __copy__(self): ) return new_image + @staticmethod + def _check_data_loader() -> None: + if torch.__version__ >= '2.3' and in_torch_loader(): + message = ( + 'Using TorchIO images without a torchio.SubjectsLoader in PyTorch >=' + ' 2.3 might have unexpected consequences, e.g., the collated batches' + ' will be instances of torchio.Subject with 5D images. Replace' + ' your PyTorch DataLoader with a torchio.SubjectsLoader so that' + ' the collated batch becomes a dictionary, as expected. See' + ' https://github.com/fepegar/torchio/issues/1179 for more' + ' context about this issue.' + ) + warnings.warn(message, stacklevel=1) + @property def data(self) -> torch.Tensor: """Tensor data (same as :class:`Image.tensor`)."""