Skip to content

Commit

Permalink
Add support for Whisper (#693)
Browse files Browse the repository at this point in the history
This PR adds adapter support for the Whisper model from openai and
builds upon work done previously in #572.

Key Additions:

1. Adapter Support for Whisper Model:

- Incorporates adapter functionality to enhance the flexibility and
adaptability of the Whisper model.

2. Enhanced Head Functions:
- Expanded the argument options for some heads by adding a layer
argument with a default value.

3. Preprocessing Scripts for Audio Datasets:

- Added preprocessing scripts tailored for audio datasets.
- These scripts are now utilized in the Whisper tests within the test
suite, replacing the use of arbitrary samples.

---------

Co-authored-by: Leon Engländer <leon.englaender@gmail.com>
Co-authored-by: calpt <calpt@mail.de>
  • Loading branch information
3 people authored Aug 8, 2024
1 parent aea6c09 commit a99e47c
Show file tree
Hide file tree
Showing 42 changed files with 1,637 additions and 48 deletions.
25 changes: 25 additions & 0 deletions docs/classes/models/whisper.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Whisper
-----------------------------------------------------------------------------------------------------------------------

The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision
<https://arxiv.org/abs/2212.04356>`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine
McLeavey, Ilya Sutskever.

Whisper is a state-of-the-art speech recognition model trained on 680,000 hours of multilingual and multitask data, presented by OpenAI.

The abstract from the paper is the following:

*We study the capabilities of speech processing systems trained simply to predict large amounts of
transcripts of audio on the internet. When scaled to 680,000 hours of multilingual and multitask
supervision, the resulting models generalize well to standard benchmarks and are often competitive
with prior fully supervised results but in a zeroshot transfer setting without the need for any finetuning. When compared to humans, the models
approach their accuracy and robustness. We are releasing models and inference code to serve as
a foundation for further work on robust speech processing.*


WhisperAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.WhisperAdapterModel
:members:
:inherited-members: WhisperPreTrainedModel
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/roberta
classes/models/t5
classes/models/vit
classes/models/whisper
classes/models/xlmroberta
classes/models/xmod

Expand Down
2 changes: 2 additions & 0 deletions docs/model_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```


| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning | ReFT |
| --------------------------------------- | -| - | - | - | - | - | - |- | - |
| [ALBERT](classes/models/albert.html) ||||||||||
Expand All @@ -33,6 +34,7 @@ The table below further shows which model architectures support which adaptation
| [RoBERTa](classes/models/roberta.html) ||||||||||
| [T5](classes/models/t5.html) |||||||| ||
| [ViT](classes/models/vit.html) ||||||||||
| [Whisper](classes/models/whisper.html) |||||||| ||
| [XLM-RoBERTa](classes/models/xlmroberta.html) ||||||||||
| [X-MOD](classes/models/xmod.html) ||||||||||

Expand Down
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
"models.roberta": ["RobertaAdapterModel"],
"models.t5": ["T5AdapterModel"],
"models.vit": ["ViTAdapterModel"],
"models.whisper": ["WhisperAdapterModel"],
"models.xlm_roberta": ["XLMRobertaAdapterModel"],
"models.xmod": ["XmodAdapterModel"],
"trainer": ["AdapterTrainer", "Seq2SeqAdapterTrainer"],
Expand Down Expand Up @@ -224,6 +225,7 @@
from .models.roberta import RobertaAdapterModel
from .models.t5 import T5AdapterModel
from .models.vit import ViTAdapterModel
from .models.whisper import WhisperAdapterModel
from .models.xlm_roberta import XLMRobertaAdapterModel
from .models.xmod import XmodAdapterModel
from .trainer import AdapterTrainer, Seq2SeqAdapterTrainer
Expand Down
1 change: 1 addition & 0 deletions src/adapters/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
"llama",
"mistral",
"electra",
"whisper",
"xmod",
],
}
Expand Down
11 changes: 10 additions & 1 deletion src/adapters/head_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

logger = logging.getLogger(__name__)


# The "layers" attributes in the configs below map from static head module names to flex head module names.
# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts).
STATIC_TO_FLEX_HEAD_MAP = {
Expand Down Expand Up @@ -771,6 +770,16 @@
"generator_lm_head",
],
},
"WhisperForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
"layers": 1,
"activation_function": None,
"layer_norm": False,
"bias": False,
},
"layers": ["proj_out"],
},
}


Expand Down
2 changes: 1 addition & 1 deletion src/adapters/heads/language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
)
labels = torch.cat((prompt_labels, labels), dim=-1)

loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))
loss = loss_fct(logits_for_loss.reshape(-1, self.config["vocab_size"]), labels.reshape(-1))

if return_dict:
return self._create_model_output(loss, lm_logits, outputs)
Expand Down
15 changes: 9 additions & 6 deletions src/adapters/heads/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

logger = logging.getLogger(__name__)


MODEL_HEAD_MAP = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
Expand Down Expand Up @@ -440,47 +439,51 @@ def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=Fals
self.add_prediction_head(head, overwrite_ok)

@head_type("masked_lm")
def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False):
def add_masked_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False):
"""
Adds a masked language modeling head on top of the model.
Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to 'gelu'.
layers (int, optional): Number of layers. Defaults to 2.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function)
head = BertStyleMaskedLMHead(self, head_name, layers=layers, activation_function=activation_function)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

@head_type("causal_lm")
def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False):
def add_causal_lm_head(self, head_name, activation_function="gelu", layers=2, overwrite_ok=False):
"""
Adds a causal language modeling head on top of the model.
Args:
head_name (str): The name of the head.
activation_function (str, optional): Activation function. Defaults to 'gelu'.
layers (int, optional): Number of layers. Defaults to 2.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = CausalLMHead(
self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True
self, head_name, layers=layers, activation_function=activation_function, layer_norm=True, bias=True
)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

@head_type("seq2seq_lm")
def add_seq2seq_lm_head(
self,
head_name,
layers=1,
overwrite_ok=False,
):
"""
Adds a sequence-to-sequence language modeling head on top of the model.
Args:
head_name (str): The name of the head.
layers (int, optional): Number of layers. Defaults to 1.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = Seq2SeqLMHead(self, head_name)
head = Seq2SeqLMHead(self, head_name, layers=layers)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)

def delete_head(self, head_name: str):
Expand Down
9 changes: 8 additions & 1 deletion src/adapters/methods/prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,14 @@ def forward(self, *args, **kwargs):
prefix_states = {}
if adapter_setup is not None:
# Infer batch size
input_tensor_names = ["input_ids", "decoder_input_ids", "attention_mask", "inputs_embeds", "pixel_values"]
input_tensor_names = [
"input_ids",
"decoder_input_ids",
"attention_mask",
"inputs_embeds",
"pixel_values",
"input_features",
]
batch_size = None
for name in input_tensor_names:
if kwargs.get(name, None) is not None:
Expand Down
14 changes: 12 additions & 2 deletions src/adapters/methods/reft.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,19 @@ def __init__(self, in_features: int, config: ReftConfig):

def _gather_adapted_states(self, hidden_states: torch.Tensor):
context = ForwardContext.get_context()
bsz, _, ddim = hidden_states.size()
bsz, seq_len, ddim = hidden_states.size()

# if cached indexing matrices are computed for different hidden_states size -> recompute
cache_invalidated = False
if hasattr(context, "pref_idx") and hasattr(context, "suff_idx"):
cache_invalidated = (
torch.max(context.suff_idx) >= seq_len # indices out of bounds
or bsz != context.suff_idx.size(0) # batch size mismatch
or ddim != context.suff_idx.size(2) # hidden size mismatch
)

# no cached indexing matrices available -> compute now
if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"):
if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx") or cache_invalidated:
# read offsets & lengths from context
if hasattr(context, "seqlens"):
first_non_padding = context.offsets
Expand Down
13 changes: 12 additions & 1 deletion src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,18 @@ def _prepare_model_inputs(self, *args, **kwargs):
and self.adapters_config.active_setup
and self.adapters_config.active_setup.parallel_channels > 1
):
input_ids = input_ids.repeat(self.adapters_config.active_setup.parallel_channels, 1)
# Extract original shape
input_shape = input_ids.shape
# Replicate input_ids to match the number of parallel channels
# Also works for inputs with more than 2 dimensions
repeat_shape = [
self.adapters_config.active_setup.parallel_channels
] + [ # first dimension is parallel channels
1
] * (
len(input_shape) - 1
) # residual dims should be replicated parallel_channels times
input_ids = input_ids.repeat(repeat_shape)
model_kwargs["adapter_input_parallelized"] = True

return input_ids, input_name, model_kwargs
Expand Down
12 changes: 12 additions & 0 deletions src/adapters/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
T5ModelAdaptersMixin,
)
from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin
from .whisper.mixin_whisper import (
WhisperDecoderAdaptersMixin,
WhisperDecoderWrapperAdaptersMixin,
WhisperEncoderAdaptersMixin,
WhisperForAudioClassificationWithHeadsMixin,
WhisperModelAdaptersMixin,
)
from .xmod.mixin_xmod import XmodModelAdaptersMixin


Expand Down Expand Up @@ -95,6 +102,11 @@
"BertGenerationEncoder": BertModelAdaptersMixin,
"BertGenerationLayer": BertLayerAdaptersMixin,
"LlamaModel": LlamaModelAdapterMixin,
"WhisperEncoder": WhisperEncoderAdaptersMixin,
"WhisperDecoder": WhisperDecoderAdaptersMixin,
"WhisperModel": WhisperModelAdaptersMixin,
"WhisperDecoderWrapper": WhisperDecoderWrapperAdaptersMixin,
"WhisperForAudioClassification": WhisperForAudioClassificationWithHeadsMixin,
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
"MistralModel": MistralModelAdapterMixin,
}
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
("roberta", "RobertaAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
("whisper", "WhisperAdapterModel"),
("xlm-roberta", "XLMRobertaAdapterModel"),
("xmod", "XmodAdapterModel"),
]
Expand Down
2 changes: 1 addition & 1 deletion src/adapters/models/mt5/modeling_mt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def forward(
if past_key_value is not None:
assert (
len(past_key_value) == 2
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
Expand Down
39 changes: 39 additions & 0 deletions src/adapters/models/whisper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["WhisperAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import WhisperAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
Loading

0 comments on commit a99e47c

Please sign in to comment.