diff --git a/src/torchio/transforms/augmentation/composition.py b/src/torchio/transforms/augmentation/composition.py index 16bd6305..6412aaa0 100644 --- a/src/torchio/transforms/augmentation/composition.py +++ b/src/torchio/transforms/augmentation/composition.py @@ -47,11 +47,8 @@ def __repr__(self) -> str: def get_init_args(self) -> Dict: init_args = super().get_init_args() - try: - # Remove parse_input as it is set to False in the __init__ + if 'parse_input' in init_args: init_args.pop('parse_input') - except KeyError: - pass return init_args def apply_transform(self, subject: Subject) -> Subject: @@ -115,11 +112,8 @@ def __init__(self, transforms: TypeTransformsDict, **kwargs): def get_init_args(self) -> Dict: init_args = super().get_init_args() - try: - # Remove parse_input as it is set to False in the __init__ + if 'parse_input' in init_args: init_args.pop('parse_input') - except KeyError: - pass return init_args def apply_transform(self, subject: Subject) -> Subject: diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 4a27c7a9..d505e15e 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -381,3 +381,27 @@ def test_bad_keys_type(self): # From https://github.com/fepegar/torchio/issues/923 with self.assertRaises(ValueError): tio.RandomAffine(include='t1') + + def test_init_args(self): + transform = tio.Compose([tio.RandomNoise()]) + init_args = transform.get_init_args() + assert 'parse_input' not in init_args + + transform = tio.OneOf([tio.RandomNoise()]) + init_args = transform.get_init_args() + assert 'parse_input' not in init_args + + transform = tio.RandomNoise() + init_args = transform.get_init_args() + assert all( + arg in init_args + for arg in [ + 'p', + 'copy', + 'include', + 'exclude', + 'keep', + 'parse_input', + 'label_keys', + ] + )