Skip to content

Commit

Permalink
Set transform kwargs in every instantiation of another transform with…
Browse files Browse the repository at this point in the history
…in the original transform
  • Loading branch information
nicoloesch committed Oct 24, 2024
1 parent cfc8aaf commit 980cb9c
Show file tree
Hide file tree
Showing 20 changed files with 34 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/torchio/transforms/augmentation/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def inverse(self, warn: bool = True) -> Compose:
message = f'Skipping {transform.name} as it is not invertible'
warnings.warn(message, RuntimeWarning, stacklevel=2)
transforms.reverse()
result = Compose(transforms)
result = Compose(transforms, **self.get_init_args())
if not transforms and warn:
warnings.warn(
'No invertible transforms found',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def apply_transform(self, subject: Subject) -> Subject:
coefficients = self.get_params(self.order, self.coefficients_range)
arguments['coefficients'][image_name] = coefficients
arguments['order'][image_name] = self.order
transform = BiasField(**self.add_include_exclude(arguments))
transform = BiasField(**arguments, **self.get_init_args())
transformed = transform(subject)
return transformed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def apply_transform(self, subject: Subject) -> Subject:
for name in images_dict:
std = self.get_params(self.std_ranges) # type: ignore[arg-type]
arguments['std'][name] = std
transform = Blur(**self.add_include_exclude(arguments))
transform = Blur(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def apply_transform(self, subject: Subject) -> Subject:
for name, image in images_dict.items():
gammas = [self.get_params(self.log_gamma_range) for _ in image.data]
arguments['gamma'][name] = gammas
transform = Gamma(**self.add_include_exclude(arguments))
transform = Gamma(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['axis'][name] = axis_param
arguments['intensity'][name] = intensity_param
arguments['restore'][name] = self.restore
transform = Ghosting(**self.add_include_exclude(arguments))
transform = Ghosting(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def apply_transform(self, subject: Subject) -> Subject:
means.append(mean)
stds.append(std)

transform = LabelsToImage(**self.add_include_exclude(arguments))
transform = LabelsToImage(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['degrees'][name] = degrees_params
arguments['translation'][name] = translation_params
arguments['image_interpolation'][name] = self.image_interpolation
transform = Motion(**self.add_include_exclude(arguments))
transform = Motion(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def apply_transform(self, subject: Subject) -> Subject:
arguments['mean'][image_name] = mean
arguments['std'][image_name] = std
arguments['seed'][image_name] = seed
transform = Noise(**self.add_include_exclude(arguments))
transform = Noise(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def apply_transform(self, subject: Subject) -> Subject:
)
arguments['spikes_positions'][image_name] = spikes_positions_param
arguments['intensity'][image_name] = intensity_param
transform = Spike(**self.add_include_exclude(arguments))
transform = Spike(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def apply_transform(self, subject: Subject) -> Subject:
)
arguments['locations'][name] = locations
arguments['patch_size'][name] = self.patch_size
transform = Swap(**self.add_include_exclude(arguments))
transform = Swap(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
5 changes: 0 additions & 5 deletions src/torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ class RandomTransform(Transform):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def add_include_exclude(self, kwargs):
kwargs['include'] = self.include
kwargs['exclude'] = self.exclude
return kwargs

def parse_degrees(
self,
degrees: TypeRangeFloat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def apply_transform(self, subject: Subject) -> Subject:
'label_interpolation': self.label_interpolation,
'check_shape': self.check_shape,
}
transform = Affine(**self.add_include_exclude(arguments))
transform = Affine(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def apply_transform(self, subject: Subject) -> Subject:
}

sx, sy, sz = target_spacing # for mypy
downsample = Resample(
target=(sx, sy, sz), **self.add_include_exclude(arguments)
)
downsample = Resample(target=(sx, sy, sz), **arguments, **self.get_init_args())
downsampled = downsample(subject)
image = subject.get_first_image()
target = image.spatial_shape, image.affine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def apply_transform(self, subject: Subject) -> Subject:
'label_interpolation': self.label_interpolation,
}

transform = ElasticDeformation(**self.add_include_exclude(arguments))
transform = ElasticDeformation(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def apply_transform(self, subject: Subject) -> Subject:
return subject

arguments = {'axes': axes}
transform = Flip(**self.add_include_exclude(arguments))
transform = Flip(**arguments, **self.get_init_args())
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def apply_transform(self, subject):
remapping = {
unique_labels[i].item(): i for i in range(0, len(unique_labels))
}
init_kwargs = self.get_init_args()
init_kwargs['include'] = [name]

transform = RemapLabels(
remapping=remapping,
masking_method=self.masking_method,
include=[name],
**init_kwargs,
)
subject = transform(subject)
return subject
4 changes: 2 additions & 2 deletions src/torchio/transforms/preprocessing/spatial/crop_or_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ def apply_transform(self, subject: Subject) -> Subject:
padding_params, cropping_params = self.compute_crop_or_pad(subject)
padding_kwargs = {'padding_mode': self.padding_mode}
if padding_params is not None:
pad = Pad(padding_params, **padding_kwargs)
pad = Pad(padding_params, **self.get_init_args(), **padding_kwargs)
subject = pad(subject) # type: ignore[assignment]
if cropping_params is not None:
crop = Crop(cropping_params)
crop = Crop(cropping_params, **self.get_init_args())
subject = crop(subject) # type: ignore[assignment]
return subject
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def apply_transform(self, subject: Subject) -> Subject:
integer_ratio = function(source_shape / self.target_multiple)
target_shape = integer_ratio * self.target_multiple
target_shape = np.maximum(target_shape, 1)
transform = CropOrPad(target_shape.astype(int))
transform = CropOrPad(target_shape.astype(int), **self.get_init_args())
subject = transform(subject) # type: ignore[assignment]
return subject
3 changes: 2 additions & 1 deletion src/torchio/transforms/preprocessing/spatial/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def apply_transform(self, subject: Subject) -> Subject:
spacing_out,
image_interpolation=self.image_interpolation,
label_interpolation=self.label_interpolation,
**self.get_init_args(),
)
resampled = resample(subject)
assert isinstance(resampled, Subject)
Expand All @@ -72,7 +73,7 @@ def apply_transform(self, subject: Subject) -> Subject:
f' != target shape {tuple(shape_out)}. Fixing with CropOrPad'
)
warnings.warn(message, RuntimeWarning, stacklevel=2)
crop_pad = CropOrPad(shape_out) # type: ignore[arg-type]
crop_pad = CropOrPad(shape_out, **self.get_init_args()) # type: ignore[arg-type]
resampled = crop_pad(resampled)
assert isinstance(resampled, Subject)
return resampled
11 changes: 11 additions & 0 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ def __repr__(self):
else:
return super().__repr__()

def get_init_args(self) -> dict:
return {
'p': self.probability,
'copy': self.copy,
'include': self.include,
'exclude': self.exclude,
'keep': self.keep,
'parse_input': self.parse_input,
'label_keys': self.label_keys,
}

@property
def name(self):
return self.__class__.__name__
Expand Down

0 comments on commit 980cb9c

Please sign in to comment.