diff --git a/src/torchio/utils.py b/src/torchio/utils.py index c7ebe97c..8c5034bc 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -248,17 +248,17 @@ def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]: Args: batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` - extracting data from a :class:`torchio.SubjectsDataset`. + extracting data from a :class:`torchio.SubjectsDataset`. Raises: RuntimeError: If the batch does not seem to contain any dictionaries - that seem to represent a :class:`torchio.Image`. + that seem to represent a :class:`torchio.Image`. """ names = [] - for image_name, image_dict in batch.items(): - if constants.DATA in image_dict: # assume it is a TorchIO Image - size = len(image_dict[constants.DATA]) - names.append(image_name) + for key, value in batch.items(): + if isinstance(value, dict) and constants.DATA in value: + size = len(value[constants.DATA]) + names.append(key) if not names: raise RuntimeError('The batch does not seem to contain any images') return names, size @@ -269,28 +269,38 @@ def get_subjects_from_batch(batch: Dict) -> List: Args: batch: Dictionary generated by a :class:`torch.utils.data.DataLoader` - extracting data from a :class:`torchio.SubjectsDataset`. + extracting data from a :class:`torchio.SubjectsDataset`. """ from .data import ScalarImage, LabelMap, Subject subjects = [] image_names, batch_size = get_batch_images_and_size(batch) + for i in range(batch_size): subject_dict = {} - for image_name in image_names: - image_dict = batch[image_name] - data = image_dict[constants.DATA][i] - affine = image_dict[constants.AFFINE][i] - path = Path(image_dict[constants.PATH][i]) - is_label = image_dict[constants.TYPE][i] == constants.LABEL - klass = LabelMap if is_label else ScalarImage - image = klass(tensor=data, affine=affine, filename=path.name) - subject_dict[image_name] = image + + for key, value in batch.items(): + if key in image_names: + image_name = key + image_dict = value + data = image_dict[constants.DATA][i] + affine = image_dict[constants.AFFINE][i] + path = Path(image_dict[constants.PATH][i]) + is_label = image_dict[constants.TYPE][i] == constants.LABEL + klass = LabelMap if is_label else ScalarImage + image = klass(tensor=data, affine=affine, filename=path.name) + subject_dict[image_name] = image + else: + instance_value = value[i] + subject_dict[key] = instance_value + subject = Subject(subject_dict) + if constants.HISTORY in batch: applied_transforms = batch[constants.HISTORY][i] for transform in applied_transforms: transform.add_transform_to_subject_history(subject) + subjects.append(subject) return subjects