Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KeyError with LabelMap-only subjects #1218

Merged
merged 3 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-sugar",
"tox",
"tox-uv",
"types-Deprecated",
]
doc = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@
self.order = _parse_order(order)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 56 in src/torchio/transforms/augmentation/intensity/random_bias_field.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_bias_field.py#L56

Added line #L56 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
coefficients = self.get_params(
self.order,
self.coefficients_range,
)
for image_name in images_dict:
coefficients = self.get_params(self.order, self.coefficients_range)
arguments['coefficients'][image_name] = coefficients
arguments['order'][image_name] = self.order
transform = BiasField(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed

def get_params(
Expand Down
12 changes: 8 additions & 4 deletions src/torchio/transforms/augmentation/intensity/random_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,22 @@
self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 44 in src/torchio/transforms/augmentation/intensity/random_blur.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_blur.py#L44

Added line #L44 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for name in self.get_images_dict(subject):
std = self.get_params(self.std_ranges)
for name in images_dict:
std = self.get_params(self.std_ranges) # type: ignore[arg-type]
arguments['std'][name] = std
transform = Blur(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed

def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat:
std = self.sample_uniform_sextet(std_ranges)
return std
sx, sy, sz = self.sample_uniform_sextet(std_ranges)
return sx, sy, sz


class Blur(IntensityTransform):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@
self.log_gamma_range = self._parse_range(log_gamma, 'log_gamma')

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 74 in src/torchio/transforms/augmentation/intensity/random_gamma.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_gamma.py#L74

Added line #L74 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
gammas = [self.get_params(self.log_gamma_range) for _ in image.data]
arguments['gamma'][name] = gammas
transform = Gamma(**self.add_include_exclude(arguments))
Expand Down
11 changes: 8 additions & 3 deletions src/torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@
self.restore = _parse_restore(restore)

def apply_transform(self, subject: Subject) -> Subject:
arguments: Dict[str, dict] = defaultdict(dict)
if any(isinstance(n, str) for n in self.axes):
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 89 in src/torchio/transforms/augmentation/intensity/random_ghosting.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_ghosting.py#L89

Added line #L89 was not covered by tests

if any(isinstance(axis, str) for axis in self.axes):
subject.check_consistent_orientation()
for name, image in self.get_images_dict(subject).items():

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in images_dict.items():
is_2d = image.is_2d()
axes = [a for a in self.axes if a != 2] if is_2d else self.axes
min_ghosts, max_ghosts = self.num_ghosts_range
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@
)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 78 in src/torchio/transforms/augmentation/intensity/random_motion.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_motion.py#L78

Added line #L78 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
params = self.get_params(
self.degrees_range,
self.translation_range,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@
self.std_range = self._parse_range(std, 'std', min_constraint=0)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 49 in src/torchio/transforms/augmentation/intensity/random_noise.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_noise.py#L49

Added line #L49 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
for image_name in images_dict:
mean, std, seed = self.get_params(self.mean_range, self.std_range)
arguments['mean'][image_name] = mean
arguments['std'][image_name] = std
Expand Down
8 changes: 6 additions & 2 deletions src/torchio/transforms/augmentation/intensity/random_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@
)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 66 in src/torchio/transforms/augmentation/intensity/random_spike.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_spike.py#L66

Added line #L66 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
for image_name in images_dict:
spikes_positions_param, intensity_param = self.get_params(
self.num_spikes_range,
self.intensity_range,
Expand Down Expand Up @@ -90,7 +94,7 @@
r"""Add MRI spike artifacts.

Also known as `Herringbone artifact
<https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
<https://radiopaedia.org/articles/herringbone-artifact>`_,
crisscross artifact or corduroy artifact, it creates stripes in different
directions in image space due to spikes in k-space.

Expand Down
6 changes: 5 additions & 1 deletion src/torchio/transforms/augmentation/intensity/random_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@
return locations # type: ignore[return-value]

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

Check warning on line 94 in src/torchio/transforms/augmentation/intensity/random_swap.py

View check run for this annotation

Codecov / codecov/patch

src/torchio/transforms/augmentation/intensity/random_swap.py#L94

Added line #L94 was not covered by tests

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
locations = self.get_params(
image.data,
self.patch_size,
Expand Down
7 changes: 5 additions & 2 deletions src/torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch

from ...typing import TypeRangeFloat
from ...typing import TypeSextetFloat
from ...typing import TypeTripletFloat
from ..transform import Transform


Expand Down Expand Up @@ -49,8 +51,9 @@ def _get_random_seed() -> int:
"""
return int(torch.randint(0, 2**31, (1,)).item())

def sample_uniform_sextet(self, params):
def sample_uniform_sextet(self, params: TypeSextetFloat) -> TypeTripletFloat:
results = []
for a, b in zip(params[::2], params[1::2]):
results.append(self.sample_uniform(a, b))
return torch.Tensor(results)
sx, sy, sz = results
return sx, sy, sz
21 changes: 15 additions & 6 deletions src/torchio/transforms/augmentation/spatial/random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,20 @@ def get_params(
translation: TypeSextetFloat,
isotropic: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scaling_params = self.sample_uniform_sextet(scales)
scaling_params = torch.as_tensor(
self.sample_uniform_sextet(scales),
dtype=torch.float64,
)
if isotropic:
scaling_params.fill_(scaling_params[0])
rotation_params = self.sample_uniform_sextet(degrees)
translation_params = self.sample_uniform_sextet(translation)
rotation_params = torch.as_tensor(
self.sample_uniform_sextet(degrees),
dtype=torch.float64,
)
translation_params = torch.as_tensor(
self.sample_uniform_sextet(translation),
dtype=torch.float64,
)
return scaling_params, rotation_params, translation_params

def apply_transform(self, subject: Subject) -> Subject:
Expand All @@ -166,9 +175,9 @@ def apply_transform(self, subject: Subject) -> Subject:
self.isotropic,
)
arguments = {
'scales': scaling_params.tolist(),
'degrees': rotation_params.tolist(),
'translation': translation_params.tolist(),
'scales': scaling_params,
'degrees': rotation_params,
'translation': translation_params,
'center': self.center,
'default_pad_value': self.default_pad_value,
'image_interpolation': self.image_interpolation,
Expand Down
5 changes: 1 addition & 4 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ def __init__(
# used to invert invertible transforms
self.args_names: List[str] = []

def __call__(
self,
data: InputType,
) -> InputType:
def __call__(self, data: InputType) -> InputType:
"""Transform data and return a result of the same type.

Args:
Expand Down
Loading