From f91f1b45c15c810637cc9d3da94905e5047793fa Mon Sep 17 00:00:00 2001 From: KonoMaxi Date: Sun, 28 Jan 2024 19:25:26 +0100 Subject: [PATCH] Fix get_subjects_from_batch ignoring metadata (#1131) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added test that shows metadata issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix get_subjects_from_batch ignoring metadata --------- Co-authored-by: Max Konowski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Fernando Pérez-García --- src/torchio/utils.py | 42 ++++++++++++++++++++++++++---------------- tests/test_utils.py | 27 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/src/torchio/utils.py b/src/torchio/utils.py index c7ebe97c4..8c5034bc8 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -248,17 +248,17 @@ def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]: Args: batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` - extracting data from a :class:`torchio.SubjectsDataset`. + extracting data from a :class:`torchio.SubjectsDataset`. Raises: RuntimeError: If the batch does not seem to contain any dictionaries - that seem to represent a :class:`torchio.Image`. + that seem to represent a :class:`torchio.Image`. """ names = [] - for image_name, image_dict in batch.items(): - if constants.DATA in image_dict: # assume it is a TorchIO Image - size = len(image_dict[constants.DATA]) - names.append(image_name) + for key, value in batch.items(): + if isinstance(value, dict) and constants.DATA in value: + size = len(value[constants.DATA]) + names.append(key) if not names: raise RuntimeError('The batch does not seem to contain any images') return names, size @@ -269,28 +269,38 @@ def get_subjects_from_batch(batch: Dict) -> List: Args: batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` - extracting data from a :class:`torchio.SubjectsDataset`. + extracting data from a :class:`torchio.SubjectsDataset`. """ from .data import ScalarImage, LabelMap, Subject subjects = [] image_names, batch_size = get_batch_images_and_size(batch) + for i in range(batch_size): subject_dict = {} - for image_name in image_names: - image_dict = batch[image_name] - data = image_dict[constants.DATA][i] - affine = image_dict[constants.AFFINE][i] - path = Path(image_dict[constants.PATH][i]) - is_label = image_dict[constants.TYPE][i] == constants.LABEL - klass = LabelMap if is_label else ScalarImage - image = klass(tensor=data, affine=affine, filename=path.name) - subject_dict[image_name] = image + + for key, value in batch.items(): + if key in image_names: + image_name = key + image_dict = value + data = image_dict[constants.DATA][i] + affine = image_dict[constants.AFFINE][i] + path = Path(image_dict[constants.PATH][i]) + is_label = image_dict[constants.TYPE][i] == constants.LABEL + klass = LabelMap if is_label else ScalarImage + image = klass(tensor=data, affine=affine, filename=path.name) + subject_dict[image_name] = image + else: + instance_value = value[i] + subject_dict[key] = instance_value + subject = Subject(subject_dict) + if constants.HISTORY in batch: applied_transforms = batch[constants.HISTORY][i] for transform in applied_transforms: transform.add_transform_to_subject_history(subject) + subjects.append(subject) return subjects diff --git a/tests/test_utils.py b/tests/test_utils.py index 676c29735..7bda672c0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -48,6 +48,33 @@ def test_subjects_from_batch(self): subjects = tio.utils.get_subjects_from_batch(batch) assert isinstance(subjects[0], tio.Subject) + def test_subjects_from_batch_with_string_metadata(self): + subject_c_with_string_metadata = tio.Subject( + name='John Doe', + label=tio.LabelMap(self.get_image_path('label_c', binary=True)), + ) + + dataset = tio.SubjectsDataset(4 * [subject_c_with_string_metadata]) + loader = torch.utils.data.DataLoader(dataset, batch_size=4) + batch = tio.utils.get_first_item(loader) + subjects = tio.utils.get_subjects_from_batch(batch) + assert isinstance(subjects[0], tio.Subject) + assert 'label' in subjects[0] + assert 'name' in subjects[0] + + def test_subjects_from_batch_with_int_metadata(self): + subject_c_with_int_metadata = tio.Subject( + age=45, + label=tio.LabelMap(self.get_image_path('label_c', binary=True)), + ) + dataset = tio.SubjectsDataset(4 * [subject_c_with_int_metadata]) + loader = torch.utils.data.DataLoader(dataset, batch_size=4) + batch = tio.utils.get_first_item(loader) + subjects = tio.utils.get_subjects_from_batch(batch) + assert isinstance(subjects[0], tio.Subject) + assert 'label' in subjects[0] + assert 'age' in subjects[0] + def test_add_images_from_batch(self): subject = copy.deepcopy(self.sample_subject) subjects = 4 * [subject]