diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 46e1c09dba1fe7..bf06d9c4053822 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1006,32 +1006,6 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
-class BackboneMixin:
-    @property
-    def out_feature_channels(self):
-        # the current backbones will output the number of channels for each stage
-        # even if that stage is not in the out_features list.
-        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
-
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
-    def forward_with_filtered_kwargs(self, *args, **kwargs):
-        signature = dict(inspect.signature(self.forward).parameters)
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
-        return self(*args, **filtered_kwargs)
-
-    def forward(
-        self,
-        pixel_values: Tensor,
-        output_hidden_states: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        raise NotImplementedError("This method should be implemented by the derived class.")
-
-
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
     Base class for all models.
diff --git a/src/transformers/models/bit/configuration_bit.py b/src/transformers/models/bit/configuration_bit.py
index da53807f3f0686..bfac3ab03f0024 100644
--- a/src/transformers/models/bit/configuration_bit.py
+++ b/src/transformers/models/bit/configuration_bit.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@
 }
 
 
-class BitConfig(PretrainedConfig):
+class BitConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate an BiT
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -128,35 +129,6 @@ def __init__(
         self.width_factor = width_factor
 
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py
index 6a63a7a3e29c69..d440f180757ba9 100644
--- a/src/transformers/models/bit/modeling_bit.py
+++ b/src/transformers/models/bit/modeling_bit.py
@@ -31,7 +31,7 @@
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -39,6 +39,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_bit import BitConfig
 
 
@@ -848,12 +849,10 @@ def __init__(self, config):
 
         self.stage_names = config.stage_names
         self.bit = BitModel(config)
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.embedding_size] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # initialize weights and apply final processing
         self.post_init()
diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py
index a4b7272295c72b..0cba7804057906 100644
--- a/src/transformers/models/convnext/configuration_convnext.py
+++ b/src/transformers/models/convnext/configuration_convnext.py
@@ -22,6 +22,7 @@
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -32,7 +33,7 @@
 }
 
 
-class ConvNextConfig(PretrainedConfig):
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
     ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -119,38 +120,9 @@ def __init__(
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
 
 
 class ConvNextOnnxConfig(OnnxConfig):
diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py
index 35302bffed66b5..1748e68aeec154 100755
--- a/src/transformers/models/convnext/modeling_convnext.py
+++ b/src/transformers/models/convnext/modeling_convnext.py
@@ -29,7 +29,7 @@
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnext import ConvNextConfig
 
 
@@ -485,16 +486,14 @@ def __init__(self, config):
 
         self.embeddings = ConvNextEmbeddings(config)
         self.encoder = ConvNextEncoder(config)
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
 
diff --git a/src/transformers/models/convnextv2/configuration_convnextv2.py b/src/transformers/models/convnextv2/configuration_convnextv2.py
index f02a21371b20d5..14dfcf85124e7f 100644
--- a/src/transformers/models/convnextv2/configuration_convnextv2.py
+++ b/src/transformers/models/convnextv2/configuration_convnextv2.py
@@ -17,6 +17,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@
 }
 
 
-class ConvNextV2Config(PretrainedConfig):
+class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate an
     ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
@@ -109,35 +110,6 @@ def __init__(
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py
index c309cdc3b6e243..c4cac4eb39fc68 100644
--- a/src/transformers/models/convnextv2/modeling_convnextv2.py
+++ b/src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -29,7 +29,7 @@
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnextv2 import ConvNextV2Config
 
 
@@ -508,16 +509,14 @@ def __init__(self, config):
 
         self.embeddings = ConvNextV2Embeddings(config)
         self.encoder = ConvNextV2Encoder(config)
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
 
diff --git a/src/transformers/models/dinat/configuration_dinat.py b/src/transformers/models/dinat/configuration_dinat.py
index 7c6a84ecdd13f8..963c72f29bd407 100644
--- a/src/transformers/models/dinat/configuration_dinat.py
+++ b/src/transformers/models/dinat/configuration_dinat.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@
 }
 
 
-class DinatConfig(PretrainedConfig):
+class DinatConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -145,35 +146,6 @@ def __init__(
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py
index 5b2394f122f1ae..7e3809c1a3033d 100644
--- a/src/transformers/models/dinat/modeling_dinat.py
+++ b/src/transformers/models/dinat/modeling_dinat.py
@@ -26,7 +26,7 @@
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -39,6 +39,7 @@
     replace_return_docstrings,
     requires_backends,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_dinat import DinatConfig
 
 
@@ -890,16 +891,14 @@ def __init__(self, config):
 
         self.embeddings = DinatEmbeddings(config)
         self.encoder = DinatEncoder(config)
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = nn.LayerNorm(num_channels)
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
 
diff --git a/src/transformers/models/focalnet/configuration_focalnet.py b/src/transformers/models/focalnet/configuration_focalnet.py
index c6814e1dda14e7..f4bcd0ddce3bf9 100644
--- a/src/transformers/models/focalnet/configuration_focalnet.py
+++ b/src/transformers/models/focalnet/configuration_focalnet.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@
 }
 
 
-class FocalNetConfig(PretrainedConfig):
+class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`FocalNetModel`]. It is used to instantiate a
     FocalNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -156,35 +157,6 @@ def __init__(
         self.layer_norm_eps = layer_norm_eps
         self.encoder_stride = encoder_stride
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py
index cfd64689763bf6..e7ebdda5e5d426 100644
--- a/src/transformers/models/focalnet/modeling_focalnet.py
+++ b/src/transformers/models/focalnet/modeling_focalnet.py
@@ -27,7 +27,7 @@
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -36,6 +36,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_focalnet import FocalNetConfig
 
 
@@ -987,11 +988,9 @@ def __init__(self, config):
 
         self.focalnet = FocalNetModel(config)
 
         self.num_features = [config.embed_dim] + config.hidden_sizes
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # initialize weights and apply final processing
         self.post_init()
diff --git a/src/transformers/models/maskformer/configuration_maskformer_swin.py b/src/transformers/models/maskformer/configuration_maskformer_swin.py
index ca60b6176eedca..7c3ac54bd80d23 100644
--- a/src/transformers/models/maskformer/configuration_maskformer_swin.py
+++ b/src/transformers/models/maskformer/configuration_maskformer_swin.py
@@ -16,12 +16,13 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
 
 
-class MaskFormerSwinConfig(PretrainedConfig):
+class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
     a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -141,35 +142,6 @@ def __init__(
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 8684a58a4e7b97..c7b74a6f2bd7cc 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -27,8 +27,9 @@
 from ...activations import ACT2FN
 from ...file_utils import ModelOutput
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_maskformer_swin import MaskFormerSwinConfig
 
 
@@ -855,14 +856,13 @@ def __init__(self, config: MaskFormerSwinConfig):
 
         self.stage_names = config.stage_names
         self.model = MaskFormerSwinModel(config)
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
 
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList(
             [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
Valid names are {self.stage_names}" - ) - if out_indices is not None: - if not isinstance(out_indices, (list, tuple)): - raise ValueError("out_indices should be a list or tuple") - for idx in out_indices: - if idx >= len(self.stage_names): - raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}") - - self.out_features = out_features - self.out_indices = out_indices + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index c5e83b29da6769..7634a08ad95bb7 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput -from ...modeling_utils import BackboneMixin, PreTrainedModel +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, @@ -39,6 +39,7 @@ replace_return_docstrings, requires_backends, ) +from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices from .configuration_nat import NatConfig @@ -868,11 +869,9 @@ def __init__(self, config): self.embeddings = NatEmbeddings(config) self.encoder = NatEncoder(config) - self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] - if config.out_indices is not None: - self.out_indices = config.out_indices - else: - self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features) + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + config.out_features, config.out_indices, self.stage_names + ) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] # Add layer norms to hidden states of out_features diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 6a88935f3b02f9..f12fe542a06735 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -22,6 +22,7 @@ from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices logger = logging.get_logger(__name__) @@ -31,7 +32,7 @@ } -class ResNetConfig(PretrainedConfig): +class ResNetConfig(BackboneConfigMixin, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an ResNet model according to the specified arguments, defining the model architecture. 
Instantiating a configuration @@ -108,38 +109,9 @@ def __init__( self.hidden_act = hidden_act self.downsample_in_first_stage = downsample_in_first_stage self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] - - if out_features is not None and out_indices is not None: - if len(out_features) != len(out_indices): - raise ValueError("out_features and out_indices should have the same length if both are set") - elif out_features != [self.stage_names[idx] for idx in out_indices]: - raise ValueError("out_features and out_indices should correspond to the same stages if both are set") - - if out_features is None and out_indices is not None: - out_features = [self.stage_names[idx] for idx in out_indices] - elif out_features is not None and out_indices is None: - out_indices = [self.stage_names.index(feature) for feature in out_features] - elif out_features is None and out_indices is None: - out_features = [self.stage_names[-1]] - out_indices = [len(self.stage_names) - 1] - - if out_features is not None: - if not isinstance(out_features, list): - raise ValueError("out_features should be a list") - for feature in out_features: - if feature not in self.stage_names: - raise ValueError( - f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}" - ) - if out_indices is not None: - if not isinstance(out_indices, (list, tuple)): - raise ValueError("out_indices should be a list or tuple") - for idx in out_indices: - if idx >= len(self.stage_names): - raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}") - - self.out_features = out_features - self.out_indices = out_indices + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) class ResNetOnnxConfig(OnnxConfig): diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 6926f6a43116cc..b177cdeda6c1ce 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -28,7 +28,7 @@ BaseModelOutputWithPoolingAndNoAttention, ImageClassifierOutputWithNoAttention, ) -from ...modeling_utils import BackboneMixin, PreTrainedModel +from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -36,6 +36,7 @@ logging, replace_return_docstrings, ) +from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices from .configuration_resnet import ResNetConfig @@ -436,11 +437,9 @@ def __init__(self, config): self.embedder = ResNetEmbeddings(config) self.encoder = ResNetEncoder(config) - self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] - if config.out_indices is not None: - self.out_indices = config.out_indices - else: - self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features) + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + config.out_features, config.out_indices, self.stage_names + ) self.num_features = [config.embedding_size] + config.hidden_sizes # initialize weights and apply final processing diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py index 612bc9949fb216..757112f8cebf49 100644 --- a/src/transformers/models/swin/configuration_swin.py +++ 
b/src/transformers/models/swin/configuration_swin.py @@ -22,6 +22,7 @@ from ...configuration_utils import PretrainedConfig from ...onnx import OnnxConfig from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices logger = logging.get_logger(__name__) @@ -34,7 +35,7 @@ } -class SwinConfig(PretrainedConfig): +class SwinConfig(BackboneConfigMixin, PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin model according to the specified arguments, defining the model architecture. Instantiating a configuration with the @@ -158,38 +159,9 @@ def __init__( # this indicates the channel dimension after the last stage of the model self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] - - if out_features is not None and out_indices is not None: - if len(out_features) != len(out_indices): - raise ValueError("out_features and out_indices should have the same length if both are set") - elif out_features != [self.stage_names[idx] for idx in out_indices]: - raise ValueError("out_features and out_indices should correspond to the same stages if both are set") - - if out_features is None and out_indices is not None: - out_features = [self.stage_names[idx] for idx in out_indices] - elif out_features is not None and out_indices is None: - out_indices = [self.stage_names.index(feature) for feature in out_features] - elif out_features is None and out_indices is None: - out_features = [self.stage_names[-1]] - out_indices = [len(self.stage_names) - 1] - - if out_features is not None: - if not isinstance(out_features, list): - raise ValueError("out_features should be a list") - for feature in out_features: - if feature not in self.stage_names: - raise ValueError( - f"Feature {feature} is not a valid feature name. 
Valid names are {self.stage_names}" - ) - if out_indices is not None: - if not isinstance(out_indices, (list, tuple)): - raise ValueError("out_indices should be a list or tuple") - for idx in out_indices: - if idx >= len(self.stage_names): - raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}") - - self.out_features = out_features - self.out_indices = out_indices + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) class SwinOnnxConfig(OnnxConfig): diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 2f7cfeb1adbde9..6482ff1b5bf206 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput -from ...modeling_utils import BackboneMixin, PreTrainedModel +from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ( ModelOutput, @@ -38,6 +38,7 @@ logging, replace_return_docstrings, ) +from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices from .configuration_swin import SwinConfig @@ -1264,16 +1265,14 @@ def __init__(self, config: SwinConfig): self.embeddings = SwinEmbeddings(config) self.encoder = SwinEncoder(config, self.embeddings.patch_grid) - self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] - if config.out_indices is not None: - self.out_indices = config.out_indices - else: - self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features) + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + config.out_features, config.out_indices, self.stage_names + ) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] # Add layer norms to hidden states of out_features hidden_states_norms = {} - for stage, num_channels in zip(self.out_features, self.channels): + for stage, num_channels in zip(self._out_features, self.channels): hidden_states_norms[stage] = nn.LayerNorm(num_channels) self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index a00866f77dc447..6143c57e92d598 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -22,8 +22,9 @@ from ... import AutoBackbone from ...modeling_outputs import SemanticSegmenterOutput -from ...modeling_utils import BackboneMixin, PreTrainedModel +from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils.backbone_utils import BackboneMixin from .configuration_upernet import UperNetConfig diff --git a/src/transformers/utils/backbone_utils.py b/src/transformers/utils/backbone_utils.py new file mode 100644 index 00000000000000..8c6b7107eb0e60 --- /dev/null +++ b/src/transformers/utils/backbone_utils.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. 
diff --git a/src/transformers/utils/backbone_utils.py b/src/transformers/utils/backbone_utils.py
new file mode 100644
index 00000000000000..8c6b7107eb0e60
--- /dev/null
+++ b/src/transformers/utils/backbone_utils.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Collection of utils to be used by backbones and their components."""
+
+import inspect
+from typing import Iterable, List, Optional, Tuple, Union
+
+
+def verify_out_features_out_indices(
+    out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
+):
+    """
+    Verify that out_indices and out_features are valid for the given stage_names.
+    """
+    if stage_names is None:
+        raise ValueError("stage_names must be set for transformers backbones")
+
+    if out_features is not None:
+        if not isinstance(out_features, list):
+            raise ValueError(f"out_features must be a list, got {type(out_features)}")
+        if any(feat not in stage_names for feat in out_features):
+            raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")
+
+    if out_indices is not None:
+        if not isinstance(out_indices, (list, tuple)):
+            raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
+        if any(idx >= len(stage_names) for idx in out_indices):
+            raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
+
+    if out_features is not None and out_indices is not None:
+        if len(out_features) != len(out_indices):
+            raise ValueError("out_features and out_indices should have the same length if both are set")
+        if out_features != [stage_names[idx] for idx in out_indices]:
+            raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+
+def _align_output_features_output_indices(
+    out_features: Optional[List[str]],
+    out_indices: Optional[Union[List[int], Tuple[int]]],
+    stage_names: List[str],
+):
+    """
+    Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
+
+    The logic is as follows:
+        - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
+          `out_indices`.
+        - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
+          `out_features`.
+        - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
+        - `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned.
+
+    Args:
+        out_features (`List[str]`): The names of the features for the backbone to output.
+        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
+        stage_names (`List[str]`): The names of the stages of the backbone.
+    """
+    if out_indices is None and out_features is None:
+        out_indices = [len(stage_names) - 1]
+        out_features = [stage_names[-1]]
+    elif out_indices is None and out_features is not None:
+        out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
+    elif out_features is None and out_indices is not None:
+        out_features = [stage_names[idx] for idx in out_indices]
+    return out_features, out_indices
+
+
+def get_aligned_output_features_output_indices(
+    out_features: Optional[List[str]],
+    out_indices: Optional[Union[List[int], Tuple[int]]],
+    stage_names: List[str],
+) -> Tuple[List[str], List[int]]:
+    """
+    Get the `out_features` and `out_indices` so that they are aligned.
+
+    The logic is as follows:
+        - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
+          `out_indices`.
+        - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
+          `out_features`.
+        - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
+        - `out_indices` and `out_features` set: they are verified to be aligned.
+
+    Args:
+        out_features (`List[str]`): The names of the features for the backbone to output.
+        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
+        stage_names (`List[str]`): The names of the stages of the backbone.
+    """
+    # First verify that the out_features and out_indices are valid
+    verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)
+    output_features, output_indices = _align_output_features_output_indices(
+        out_features=out_features, out_indices=out_indices, stage_names=stage_names
+    )
+    # Verify that the aligned out_features and out_indices are valid
+    verify_out_features_out_indices(out_features=output_features, out_indices=output_indices, stage_names=stage_names)
+    return output_features, output_indices
+
+
+class BackboneMixin:
+    @property
+    def out_feature_channels(self):
+        # the current backbones will output the number of channels for each stage
+        # even if that stage is not in the out_features list.
+        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
+
+    @property
+    def channels(self):
+        return [self.out_feature_channels[name] for name in self.out_features]
+
+    def forward_with_filtered_kwargs(self, *args, **kwargs):
+        signature = dict(inspect.signature(self.forward).parameters)
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
+        return self(*args, **filtered_kwargs)
+
+    def forward(
+        self,
+        pixel_values,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        raise NotImplementedError("This method should be implemented by the derived class.")
+
+    @property
+    def out_features(self):
+        return self._out_features
+
+    @out_features.setter
+    def out_features(self, out_features: List[str]):
+        """
+        Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
+        """
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=None, stage_names=self.stage_names
+        )
+
+    @property
+    def out_indices(self):
+        return self._out_indices
+
+    @out_indices.setter
+    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
+        """
+        Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
+        """
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=None, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+class BackboneConfigMixin:
+    """
+    A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.
+    """
+
+    @property
+    def out_features(self):
+        return self._out_features
+
+    @out_features.setter
+    def out_features(self, out_features: List[str]):
+        """
+        Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
+        """
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=None, stage_names=self.stage_names
+        )
+
+    @property
+    def out_indices(self):
+        return self._out_indices
+
+    @out_indices.setter
+    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
+        """
+        Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
+        """
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=None, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
+        include the `out_features` and `out_indices` attributes.
+        """
+        output = super().to_dict()
+        output["out_features"] = output.pop("_out_features")
+        output["out_indices"] = output.pop("_out_indices")
+        return output
diff --git a/tests/test_backbone_common.py b/tests/test_backbone_common.py
index 80e68a2f44ad35..fd9bbe3bfbfe7c 100644
--- a/tests/test_backbone_common.py
+++ b/tests/test_backbone_common.py
@@ -81,9 +81,15 @@ def test_channels(self):
             out_channels = [num_features[idx] for idx in out_indices]
             self.assertListEqual(model.channels, out_channels)
 
-            config.out_features = None
-            config.out_indices = None
-            model = model_class(config)
+            new_config = copy.deepcopy(config)
+            new_config.out_features = None
+            model = model_class(new_config)
+            self.assertEqual(len(model.channels), 1)
+            self.assertListEqual(model.channels, [num_features[-1]])
+
+            new_config = copy.deepcopy(config)
+            new_config.out_indices = None
+            model = model_class(new_config)
             self.assertEqual(len(model.channels), 1)
             self.assertListEqual(model.channels, [num_features[-1]])
 
@@ -102,6 +108,15 @@ def test_create_from_modified_config(self):
         # Check output of last stage is taken if out_features=None, out_indices=None
         modified_config = copy.deepcopy(config)
         modified_config.out_features = None
+        model = model_class(modified_config)
+        model.to(torch_device)
+        model.eval()
+        result = model(**inputs_dict)
+
+        self.assertEqual(len(result.feature_maps), 1)
+        self.assertEqual(len(model.channels), 1)
+
+        modified_config = copy.deepcopy(config)
         modified_config.out_indices = None
         model = model_class(modified_config)
         model.to(torch_device)
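The new unit tests below pin down the helper's behavior. For reference, a short interactive sketch of the same cases (stage names are illustrative):

```python
from transformers.utils.backbone_utils import (
    get_aligned_output_features_output_indices,
    verify_out_features_out_indices,
)

stage_names = ["stem", "stage1", "stage2"]

# Both unset: default to the last stage
print(get_aligned_output_features_output_indices(None, None, stage_names))
# (['stage2'], [2])

# Negative indices are kept as-is; names are resolved through them
print(get_aligned_output_features_output_indices(None, [-1], stage_names))
# (['stage2'], [-1])

# Mismatched pairs raise a ValueError
try:
    verify_out_features_out_indices(["stem"], (2,), stage_names)
except ValueError as e:
    print(e)  # out_features and out_indices should correspond to the same stages ...
```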
diff --git a/tests/utils/test_backbone_utils.py b/tests/utils/test_backbone_utils.py
new file mode 100644
index 00000000000000..66b7087da2463b
--- /dev/null
+++ b/tests/utils/test_backbone_utils.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.utils.backbone_utils import (
+    BackboneMixin,
+    get_aligned_output_features_output_indices,
+    verify_out_features_out_indices,
+)
+
+
+class BackboneUtilsTester(unittest.TestCase):
+    def test_get_aligned_output_features_output_indices(self):
+        stage_names = ["a", "b", "c"]
+
+        # Defaults to last layer if both are None
+        out_features, out_indices = get_aligned_output_features_output_indices(None, None, stage_names)
+        self.assertEqual(out_features, ["c"])
+        self.assertEqual(out_indices, [2])
+
+        # Out indices set to match out features
+        out_features, out_indices = get_aligned_output_features_output_indices(["a", "c"], None, stage_names)
+        self.assertEqual(out_features, ["a", "c"])
+        self.assertEqual(out_indices, [0, 2])
+
+        # Out features set to match out indices
+        out_features, out_indices = get_aligned_output_features_output_indices(None, [0, 2], stage_names)
+        self.assertEqual(out_features, ["a", "c"])
+        self.assertEqual(out_indices, [0, 2])
+
+        # Out features selected from negative indices
+        out_features, out_indices = get_aligned_output_features_output_indices(None, [-3, -1], stage_names)
+        self.assertEqual(out_features, ["a", "c"])
+        self.assertEqual(out_indices, [-3, -1])
+
+    def test_verify_out_features_out_indices(self):
+        # Stage names must be set
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(["a", "b"], (0, 1), None)
+
+        # Out features must be a list
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"])
+
+        # Out features must be a subset of stage names
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(["a", "b"], (0, 1), ["a"])
+
+        # Out indices must be a list or tuple
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(None, 0, ["a", "b"])
+
+        # Out indices must be valid indices for the stage names
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(None, (0, 1), ["a"])
+
+        # Out features and out indices must be the same length
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"])
+
+        # Out features should match out indices
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"])
+
+        # Out features and out indices should be in order
+        with self.assertRaises(ValueError):
+            verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"])
+
+        # Check passes with valid inputs
+        verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"])
+
+    def test_backbone_mixin(self):
+        backbone = BackboneMixin()
+
+        backbone.stage_names = ["a", "b", "c"]
+        backbone._out_features = ["a", "c"]
+        backbone._out_indices = [0, 2]
+
+        # Check that the output features and indices are set correctly
+        self.assertEqual(backbone.out_features, ["a", "c"])
+        self.assertEqual(backbone.out_indices, [0, 2])
+
+        # Check out features and indices are updated correctly
+        backbone.out_features = ["a", "b"]
+        self.assertEqual(backbone.out_features, ["a", "b"])
+        self.assertEqual(backbone.out_indices, [0, 1])
+
+        backbone.out_indices = [-3, -1]
+        self.assertEqual(backbone.out_features, ["a", "c"])
+        self.assertEqual(backbone.out_indices, [-3, -1])
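Because `out_features` and `out_indices` are now plain properties backed by private attributes, assigning one re-derives the other, and the `to_dict()` override re-exposes the public names. A quick sketch, assuming the library's default 4-stage ResNet layout (`depths=[3, 4, 6, 3]`):

```python
from transformers import ResNetConfig

config = ResNetConfig(out_features=["stage1"])
print(config.out_indices)  # [1]

# Assigning out_indices re-aligns out_features (and vice versa)
config.out_indices = [2, 4]
print(config.out_features)  # ['stage2', 'stage4']

# Serialization exposes the public names, not the private attributes
d = config.to_dict()
print(d["out_features"], d["out_indices"])  # ['stage2', 'stage4'] [2, 4]
```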