Skip to content

Commit

Permalink
Add methods to update and verify out_features out_indices (huggingface#23031)
Browse files Browse the repository at this point in the history

* Add methods to update and verify out_features out_indices

* Safe update for config attributes

* Fix function names

* Save config correctly

* PR comments - use property setters

* PR comment - directly set attributes

* Update test

* Add updates to recently merged focalnet backbone
  • Loading branch information
amyeroberts authored May 4, 2023
1 parent 78b7deb commit 90e8263
Show file tree
Hide file tree
Showing 23 changed files with 420 additions and 385 deletions.
26 changes: 0 additions & 26 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,32 +1006,6 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class BackboneMixin:
    """Helpers shared by backbone models that expose per-stage feature maps.

    The mixin defines no attributes of its own. Its properties read
    ``self.stage_names``, ``self.num_features`` and ``self.out_features``,
    which the concrete class is expected to provide.
    """

    @property
    def out_feature_channels(self):
        """Map every stage name to its number of channels.

        Returns:
            dict: ``{stage: self.num_features[i]}`` for each ``stage`` at
            position ``i`` of ``self.stage_names``.
        """
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}

    @property
    def channels(self):
        """List the channel counts of the stages named in ``self.out_features``.

        Returns:
            list: one entry per name in ``self.out_features``, in that order.
        """
        # Build the stage -> channels mapping once instead of once per name.
        stage_channels = self.out_feature_channels
        return [stage_channels[name] for name in self.out_features]

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        """Call the model, dropping keyword arguments ``forward`` does not declare.

        Positional arguments are passed through unchanged. A keyword argument is
        kept only if its name is a parameter in the signature of ``self.forward``.

        Returns:
            The result of ``self(*args, **filtered_kwargs)``.
        """
        accepted = inspect.signature(self.forward).parameters
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        # The instance itself is called (not ``self.forward`` directly).
        return self(*args, **filtered_kwargs)

    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Placeholder forward pass that derived classes must override.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("This method should be implemented by the derived class.")


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
Expand Down
38 changes: 5 additions & 33 deletions src/transformers/models/bit/configuration_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)
Expand All @@ -25,7 +26,7 @@
}


class BitConfig(PretrainedConfig):
class BitConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate an BiT
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
Expand Down Expand Up @@ -128,35 +129,6 @@ def __init__(
self.width_factor = width_factor

self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]

if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]

if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")

self.out_features = out_features
self.out_indices = out_indices
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
11 changes: 5 additions & 6 deletions src/transformers/models/bit/modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_bit import BitConfig


Expand Down Expand Up @@ -848,12 +849,10 @@ def __init__(self, config):
self.stage_names = config.stage_names
self.bit = BitModel(config)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
self.num_features = [config.embedding_size] + config.hidden_sizes
if config.out_indices is not None:
self.out_indices = config.out_indices
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)

# initialize weights and apply final processing
self.post_init()
Expand Down
38 changes: 5 additions & 33 deletions src/transformers/models/convnext/configuration_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)
Expand All @@ -32,7 +33,7 @@
}


class ConvNextConfig(PretrainedConfig):
class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
Expand Down Expand Up @@ -119,38 +120,9 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.image_size = image_size
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]

if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]

if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")

self.out_features = out_features
self.out_indices = out_indices
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)


class ConvNextOnnxConfig(OnnxConfig):
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/convnext/modeling_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_convnext import ConvNextConfig


Expand Down Expand Up @@ -485,16 +486,14 @@ def __init__(self, config):
self.embeddings = ConvNextEmbeddings(config)
self.encoder = ConvNextEncoder(config)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
if config.out_indices is not None:
self.out_indices = config.out_indices
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)

# Add layer norms to hidden states of out_features
hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels):
for stage, num_channels in zip(self._out_features, self.channels):
hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

Expand Down
38 changes: 5 additions & 33 deletions src/transformers/models/convnextv2/configuration_convnextv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)
Expand All @@ -26,7 +27,7 @@
}


class ConvNextV2Config(PretrainedConfig):
class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate an
ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
Expand Down Expand Up @@ -109,35 +110,6 @@ def __init__(
self.drop_path_rate = drop_path_rate
self.image_size = image_size
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]

if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]

if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")

self.out_features = out_features
self.out_indices = out_indices
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
13 changes: 6 additions & 7 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_convnextv2 import ConvNextV2Config


Expand Down Expand Up @@ -508,16 +509,14 @@ def __init__(self, config):
self.embeddings = ConvNextV2Embeddings(config)
self.encoder = ConvNextV2Encoder(config)

self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
if config.out_indices is not None:
self.out_indices = config.out_indices
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)

# Add layer norms to hidden states of out_features
hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels):
for stage, num_channels in zip(self._out_features, self.channels):
hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

Expand Down
38 changes: 5 additions & 33 deletions src/transformers/models/dinat/configuration_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)
Expand All @@ -26,7 +27,7 @@
}


class DinatConfig(PretrainedConfig):
class DinatConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
Expand Down Expand Up @@ -145,35 +146,6 @@ def __init__(
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]

if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]

if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")

self.out_features = out_features
self.out_indices = out_indices
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
Loading

0 comments on commit 90e8263

Please sign in to comment.