Skip to content

Commit

Permalink
Add local version of _load_pretrained_model in mt5 adapter_model
Browse files Browse the repository at this point in the history
  • Loading branch information
William committed Jan 8, 2024
1 parent 2feb7f6 commit 3dcae89
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions src/adapters/models/mt5/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
MT5PreTrainedModel,
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_utils import PreTrainedModel

from ...composition import adjust_tensors_for_parallel
from ...heads import (
Expand All @@ -20,6 +21,7 @@
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
from ...loading import PredictionHeadLoader


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -264,3 +266,48 @@ def add_classification_head(
else:
head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label)
self.add_prediction_head(head, overwrite_ok)

# Hook invoked by from_pretrained() while a checkpoint's state_dict is being
# applied to the model. Overridden so that static prediction-head weights found
# in the checkpoint are converted into a flex head before delegating to the
# stock transformers implementation.
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    loaded_keys,
    *args,
    **kwargs,
):
    # Pick out the weights that are NOT part of the base model, i.e. the
    # prediction-head weights. A missing state_dict stays None.
    head_state_dict = None
    if state_dict is not None:
        prefix = cls.base_model_prefix
        head_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}

    head_name = "default"
    loader = PredictionHeadLoader(model, error_on_missing=False, convert_to_flex_head=True)
    head_config, new_head_state_dict = loader.convert_static_to_flex_head(head_state_dict, load_as=head_name)

    # Register the converted head configuration on the model, replacing any
    # head already stored under this name.
    if head_config is not None:
        if head_name in model.heads:
            logger.warning("Overwriting existing head '{}'".format(head_name))
        model.add_prediction_head_from_config(head_name, head_config, overwrite_ok=True)

    # Swap the static head entries in state_dict for their flex-head
    # equivalents, keeping loaded_keys consistent with state_dict.
    if new_head_state_dict is not None:
        for stale_key in head_state_dict:
            del state_dict[stale_key]
            loaded_keys.remove(stale_key)
        for fresh_key, tensor in new_head_state_dict.items():
            state_dict[fresh_key] = tensor
            loaded_keys.append(fresh_key)

    # Hand the (possibly rewritten) state_dict to the base implementation,
    # which performs the actual weight loading.
    return PreTrainedModel._load_pretrained_model(
        model,
        state_dict,
        loaded_keys,
        *args,
        **kwargs,
    )

0 comments on commit 3dcae89

Please sign in to comment.