diff --git a/src/torchio/transforms/augmentation/composition.py b/src/torchio/transforms/augmentation/composition.py index 2e8b7b3c..c09aa856 100644 --- a/src/torchio/transforms/augmentation/composition.py +++ b/src/torchio/transforms/augmentation/composition.py @@ -2,6 +2,8 @@ import warnings from collections.abc import Sequence +from typing import Any +from typing import TypeAlias from typing import Union import numpy as np @@ -11,7 +13,9 @@ from ..transform import Transform from . import RandomTransform -TypeTransformsDict = Union[dict[Transform, float], Sequence[Transform]] +TypeTransformsDict: TypeAlias = Union[dict[Transform, float], Sequence[Transform]] +HydraConfig: TypeAlias = dict[str, Any] +HydraConfigDict: TypeAlias = dict[str, HydraConfig] class Compose(Transform): @@ -81,6 +85,15 @@ def inverse(self, warn: bool = True) -> Compose: ) return result + def to_hydra_config(self) -> HydraConfig: + """Return a dictionary representation of the transform for Hydra instantiation.""" + transform_dict: HydraConfig = {'_target_': self._get_name_with_module()} + transform_dict['transforms'] = [] + transform_dict.update(self._get_reproducing_arguments()) + for transform in self.transforms: + transform_dict['transforms'].append(transform.to_hydra_config()) + return self._tuples_to_lists(transform_dict) + class OneOf(RandomTransform): """Apply only one of the given transforms. diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 2375e7ce..26735d09 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -596,3 +596,22 @@ def get_mask_from_bounds( mask = torch.zeros_like(tensor, dtype=torch.bool) mask[:, i0:i1, j0:j1, k0:k1] = True return mask + + def _get_name_with_module(self) -> str: + """Return the name of the transform including its module.""" + return f'{self.__class__.__module__}.{self.__class__.__name__}' + + @staticmethod + def _tuples_to_lists(obj): + if isinstance(obj, (tuple, list)): + return [Transform._tuples_to_lists(x) for x in obj] + if isinstance(obj, dict): + return {k: Transform._tuples_to_lists(v) for k, v in obj.items()} + return obj + + def to_hydra_config(self) -> dict: + """Return a dictionary representation of the transform for Hydra instantiation.""" + target = self._get_name_with_module() + transform_dict = {'_target_': target} + transform_dict.update(self._get_reproducing_arguments()) + return self._tuples_to_lists(transform_dict) diff --git a/src/torchio/visualization.py b/src/torchio/visualization.py index f177862a..dc260c8e 100644 --- a/src/torchio/visualization.py +++ b/src/torchio/visualization.py @@ -103,6 +103,14 @@ def plot_volume( elif rgb and image.num_channels == 3: data = image.data # keep image as it is elif channel is None: + if image.num_channels > 1: + message = ( + 'Multiple channels found in the image. ' + 'Plotting the first channel (0). ' + 'To plot a different channel, please specify the channel ' + 'index using the "channel" argument.' + ) + warnings.warn(message, RuntimeWarning, stacklevel=2) data = image.data[0:1] # just use the first channel else: data = image.data[np.newaxis, channel]