Skip to content

Commit

Permalink
Fix get_subjects_from_batch ignoring metadata (#1131)
Browse files Browse the repository at this point in the history
* added test that shows metadata issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix get_subjects_from_batch ignoring metadata

---------

Co-authored-by: Max Konowski <max.konowski@uni-muenster.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fernando Pérez-García <fepegar@gmail.com>
  • Loading branch information
4 people authored Jan 28, 2024
1 parent 08cd1a3 commit f91f1b4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
42 changes: 26 additions & 16 deletions src/torchio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ def get_batch_images_and_size(batch: Dict) -> Tuple[List[str], int]:
Args:
batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
extracting data from a :class:`torchio.SubjectsDataset`.
extracting data from a :class:`torchio.SubjectsDataset`.
Raises:
RuntimeError: If the batch does not seem to contain any dictionaries
that seem to represent a :class:`torchio.Image`.
that seem to represent a :class:`torchio.Image`.
"""
names = []
for image_name, image_dict in batch.items():
if constants.DATA in image_dict: # assume it is a TorchIO Image
size = len(image_dict[constants.DATA])
names.append(image_name)
for key, value in batch.items():
if isinstance(value, dict) and constants.DATA in value:
size = len(value[constants.DATA])
names.append(key)
if not names:
raise RuntimeError('The batch does not seem to contain any images')
return names, size
Expand All @@ -269,28 +269,38 @@ def get_subjects_from_batch(batch: Dict) -> List:
Args:
batch: Dictionary generated by a :class:`torch.utils.data.DataLoader`
extracting data from a :class:`torchio.SubjectsDataset`.
extracting data from a :class:`torchio.SubjectsDataset`.
"""
from .data import ScalarImage, LabelMap, Subject

subjects = []
image_names, batch_size = get_batch_images_and_size(batch)

for i in range(batch_size):
subject_dict = {}
for image_name in image_names:
image_dict = batch[image_name]
data = image_dict[constants.DATA][i]
affine = image_dict[constants.AFFINE][i]
path = Path(image_dict[constants.PATH][i])
is_label = image_dict[constants.TYPE][i] == constants.LABEL
klass = LabelMap if is_label else ScalarImage
image = klass(tensor=data, affine=affine, filename=path.name)
subject_dict[image_name] = image

for key, value in batch.items():
if key in image_names:
image_name = key
image_dict = value
data = image_dict[constants.DATA][i]
affine = image_dict[constants.AFFINE][i]
path = Path(image_dict[constants.PATH][i])
is_label = image_dict[constants.TYPE][i] == constants.LABEL
klass = LabelMap if is_label else ScalarImage
image = klass(tensor=data, affine=affine, filename=path.name)
subject_dict[image_name] = image
else:
instance_value = value[i]
subject_dict[key] = instance_value

subject = Subject(subject_dict)

if constants.HISTORY in batch:
applied_transforms = batch[constants.HISTORY][i]
for transform in applied_transforms:
transform.add_transform_to_subject_history(subject)

subjects.append(subject)
return subjects

Expand Down
27 changes: 27 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ def test_subjects_from_batch(self):
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)

def test_subjects_from_batch_with_string_metadata(self):
subject_c_with_string_metadata = tio.Subject(
name='John Doe',
label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
)

dataset = tio.SubjectsDataset(4 * [subject_c_with_string_metadata])
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
batch = tio.utils.get_first_item(loader)
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)
assert 'label' in subjects[0]
assert 'name' in subjects[0]

def test_subjects_from_batch_with_int_metadata(self):
subject_c_with_int_metadata = tio.Subject(
age=45,
label=tio.LabelMap(self.get_image_path('label_c', binary=True)),
)
dataset = tio.SubjectsDataset(4 * [subject_c_with_int_metadata])
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
batch = tio.utils.get_first_item(loader)
subjects = tio.utils.get_subjects_from_batch(batch)
assert isinstance(subjects[0], tio.Subject)
assert 'label' in subjects[0]
assert 'age' in subjects[0]

def test_add_images_from_batch(self):
subject = copy.deepcopy(self.sample_subject)
subjects = 4 * [subject]
Expand Down

0 comments on commit f91f1b4

Please sign in to comment.