From 3dcae896e1955b4a7068c35e34d3cae1d183633d Mon Sep 17 00:00:00 2001 From: William Date: Mon, 8 Jan 2024 15:11:41 +0100 Subject: [PATCH] Add local version of _load_pretrained_model in mt5 adapter_model --- src/adapters/models/mt5/adapter_model.py | 47 ++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/adapters/models/mt5/adapter_model.py b/src/adapters/models/mt5/adapter_model.py index 58bb236469..f749d69fea 100644 --- a/src/adapters/models/mt5/adapter_model.py +++ b/src/adapters/models/mt5/adapter_model.py @@ -9,6 +9,7 @@ MT5PreTrainedModel, ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward +from transformers.modeling_utils import PreTrainedModel from ...composition import adjust_tensors_for_parallel from ...heads import ( @@ -20,6 +21,7 @@ ) from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init +from ...loading import PredictionHeadLoader logger = logging.getLogger(__name__) @@ -264,3 +266,48 @@ def add_classification_head( else: head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) self.add_prediction_head(head, overwrite_ok) + + # This method is called during model loading in from_pretrained() to apply the state_dict to the model. + # Override it to inject adapter head logic. 
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    loaded_keys,
    *args,
    **kwargs,
):
    """Convert a static prediction head in ``state_dict`` to a flex head before loading.

    Called by ``from_pretrained()`` to apply the checkpoint's ``state_dict`` to
    ``model``. Weights whose keys do not start with ``cls.base_model_prefix`` are
    treated as prediction-head weights: they are converted by ``PredictionHeadLoader``
    into the flex-head format, the converted head is registered on ``model`` under the
    name ``"default"``, and ``state_dict``/``loaded_keys`` are rewritten *in place* to
    refer to the converted keys. Loading is then delegated to the parent implementation.

    Args:
        model: The (already instantiated) model the weights are loaded into.
        state_dict: Checkpoint state dict, or ``None`` when weights are loaded lazily.
        loaded_keys: List of checkpoint keys; mutated in place alongside ``state_dict``.

    Returns:
        Whatever the parent ``_load_pretrained_model`` implementation returns.
    """
    # Filter only weights not part of the base model, i.e. the prediction head weights.
    if state_dict is not None:
        head_state_dict = {
            key: value for key, value in state_dict.items() if not key.startswith(cls.base_model_prefix)
        }
    else:
        head_state_dict = None
    head_name = "default"
    loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True)
    head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name)

    if head_config is not None:
        # Register the converted head from its config, replacing any existing head of
        # the same name.
        if head_name in model.heads:
            logger.warning("Overwriting existing head '{}'".format(head_name))
        model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True)

    if new_head_state_dict is not None:
        # Swap the static head weights for their flex-head counterparts in place so the
        # parent loader only ever sees the converted keys.
        for k in head_state_dict:
            del state_dict[k]
            # Guard: a head key may be absent from loaded_keys (e.g. already filtered
            # upstream); list.remove would raise ValueError in that case.
            if k in loaded_keys:
                loaded_keys.remove(k)
        for k in new_head_state_dict:
            state_dict[k] = new_head_state_dict[k]
            loaded_keys.append(k)

    # Delegate via super() so that ``cls`` keeps pointing at the concrete model class.
    # Calling PreTrainedModel._load_pretrained_model(...) directly would rebind the
    # classmethod's ``cls`` to the base class and skip any overrides between it and
    # this class in the MRO.
    return super()._load_pretrained_model(
        model,
        state_dict,
        loaded_keys,
        *args,
        **kwargs,
    )