[WIP] Added Adapters to Whisper Model from openai #572

Closed
17 changes: 17 additions & 0 deletions adapter_docs/classes/models/whisper.rst
@@ -0,0 +1,17 @@
Whisper
-----------------------------------------------------------------------------------------------------------------------

The Whisper model was presented in `Robust Speech Recognition via Large-Scale Weak Supervision
<https://arxiv.org/abs/2212.04356>`_ by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine
McLeavey and Ilya Sutskever.

According to the abstract, Whisper is trained on 680,000 hours of multilingual and multitask supervision, a previously
unseen scale, and approaches the accuracy and robustness of human speech recognition.


WhisperAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.adapters.WhisperAdapterModel
:members:
:inherited-members: WhisperPreTrainedModel
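
For orientation, a minimal usage sketch of the new class. It assumes the standard adapter-transformers flexible-head API; the checkpoint name `openai/whisper-tiny` and the adapter name are illustrative.

```python
from transformers.adapters import WhisperAdapterModel

# Illustrative checkpoint; any Whisper checkpoint should load the same way.
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")

# Add a bottleneck adapter and a seq2seq LM head, then train only the adapter.
model.add_adapter("asr_adapter")
model.add_seq2seq_lm_head("asr_adapter")
model.train_adapter("asr_adapter")
```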
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -70,6 +70,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/gpt2
classes/models/gptj
classes/models/mbart
classes/models/whisper
classes/models/roberta
classes/models/t5
classes/models/vit
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
@@ -25,6 +25,7 @@ The table below further shows which model architectures support which adaptation
| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Whisper](classes/models/whisper.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
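
Each column in this table corresponds to an adapter method that can be selected via a config when adding an adapter. A minimal sketch of what a checked cell for Whisper implies, using LoRA as one example (`LoRAConfig` is the existing adapter-transformers config class; checkpoint and adapter names are illustrative):

```python
from transformers.adapters import LoRAConfig, WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")  # illustrative

# Any supported method config can be passed here; LoRA is one example.
model.add_adapter("whisper_lora", config=LoRAConfig())
model.train_adapter("whisper_lora")
```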
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -2604,6 +2604,8 @@
"MAMConfig",
"MBartAdapterModel",
"MBartModelWithHeads",
"WhisperAdapterModel",
"WhisperModelWithHeads",
"ModelAdaptersConfig",
"ModelAdaptersMixin",
"ModelWithFlexibleHeadsAdaptersMixin",
@@ -5708,6 +5710,8 @@
    MAMConfig,
    MBartAdapterModel,
    MBartModelWithHeads,
    WhisperAdapterModel,
    WhisperModelWithHeads,
    ModelAdaptersConfig,
    ModelAdaptersMixin,
    ModelWithFlexibleHeadsAdaptersMixin,
5 changes: 5 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -118,6 +118,10 @@
"MBartAdapterModel",
"MBartModelWithHeads",
],
"models.whisper": [
"WhisperAdapterModel",
"WhisperModelWithHeads",
],
"models.roberta": [
"RobertaAdapterModel",
"RobertaModelWithHeads",
@@ -219,6 +223,7 @@
    from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
    from .models.gptj import GPTJAdapterModel
    from .models.mbart import MBartAdapterModel, MBartModelWithHeads
    from .models.whisper import WhisperAdapterModel, WhisperModelWithHeads
    from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
    from .models.t5 import T5AdapterModel, T5ModelWithHeads
    from .models.vit import ViTAdapterModel
1 change: 1 addition & 0 deletions src/transformers/adapters/composition.py
@@ -108,6 +108,7 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b
"deberta",
"bart",
"mbart",
"whisper",
"gpt2",
"gptj",
"t5",
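
This list appears to gate which architectures may be used with parallel adapter composition, so adding `"whisper"` enables it for the new model. A minimal sketch, assuming the usual composition API (adapter names illustrative):

```python
import transformers.adapters.composition as ac
from transformers.adapters import WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")  # illustrative
model.add_adapter("adapter_a")
model.add_adapter("adapter_b")

# Run both adapters side by side in a single forward pass.
model.set_active_adapters(ac.Parallel("adapter_a", "adapter_b"))
```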
7 changes: 7 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -314,6 +314,13 @@
        },
        "layers": ["lm_head"],
    },
    # Whisper
    "WhisperForConditionalGeneration": {
        "config": {
            "head_type": "seq2seq_lm",
        },
        "layers": ["proj_out"],
    },
    # DistilBERT
    "DistilBertForSequenceClassification": {
        "config": {
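
This entry lets the flexible-head loader translate the static `WhisperForConditionalGeneration` head (its `proj_out` projection) into a `seq2seq_lm` prediction head when such a checkpoint is loaded into the adapter model class. A minimal sketch of the expected behavior (checkpoint name illustrative):

```python
from transformers.adapters import WhisperAdapterModel

# Loading weights saved from WhisperForConditionalGeneration should convert
# the static proj_out layer into a flexible "seq2seq_lm" head.
model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
print(model.heads)  # expected: a seq2seq LM head under the default head name
```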
51 changes: 51 additions & 0 deletions src/transformers/adapters/mixins/whisper.py
@@ -0,0 +1,51 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import (
    EmbeddingAdaptersMixin,
    EmbeddingAdaptersWrapperMixin,
    InvertibleAdaptersWrapperMixin,
    ModelAdaptersMixin,
    ModelWithHeadsAdaptersMixin,
)


class WhisperEncoderLayerAdaptersMixin:
    """Adds adapters to the WhisperEncoderLayer module of Whisper."""

    def _init_adapter_modules(self):
        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
        self.output_adapters = AdapterLayer("output_adapter", self.config)
        self.attention_adapters._init_adapter_modules()
        self.output_adapters._init_adapter_modules()


class WhisperDecoderLayerAdaptersMixin(WhisperEncoderLayerAdaptersMixin):
    """Adds adapters to the WhisperDecoderLayer module of Whisper."""

    def _init_adapter_modules(self):
        super()._init_adapter_modules()
        self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config)
        self.cross_attention_adapters._init_adapter_modules()


class WhisperModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelAdaptersMixin):
    """Adds adapters to the WhisperModel class."""

    invertible_adapters_base_name = "encoder"

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        if hasattr(self, "encoder"):
            for i, layer in enumerate(self.encoder.layers):
                yield i, layer
            for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)):
                yield i, layer
        else:
            for i, layer in enumerate(self.decoder.layers):
                yield i, layer


class WhisperModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
    pass
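
`iter_layers` numbers encoder layers first and continues the count into the decoder, so every layer receives a unique index for adapter bookkeeping. A small sketch of the expected enumeration; the layer counts in the comment assume `openai/whisper-tiny` (4 encoder and 4 decoder layers):

```python
from transformers.adapters import WhisperAdapterModel

model = WhisperAdapterModel.from_pretrained("openai/whisper-tiny")
for idx, layer in model.base_model.iter_layers():
    print(idx, type(layer).__name__)
# Expected: indices 0-3 -> WhisperEncoderLayer, 4-7 -> WhisperDecoderLayer
# (whisper-tiny has 4 layers on each side).
```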
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -19,6 +19,7 @@
("deberta", "DebertaAdapterModel"),
("bart", "BartAdapterModel"),
("mbart", "MBartAdapterModel"),
("whisper", "WhisperAdapterModel"),
("gpt2", "GPT2AdapterModel"),
("gptj", "GPTJAdapterModel"),
("t5", "T5AdapterModel"),
@@ -33,6 +34,7 @@
("distilbert", "DistilBertModelWithHeads"),
("bart", "BartModelWithHeads"),
("mbart", "MBartModelWithHeads"),
("whisper", "WhisperModelWithHeads"),
("gpt2", "GPT2ModelWithHeads"),
("t5", "T5ModelWithHeads"),
]
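
With these mappings registered, the auto classes should resolve Whisper checkpoints to the new adapter classes. A minimal sketch (checkpoint name illustrative):

```python
from transformers.adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("openai/whisper-tiny")
assert type(model).__name__ == "WhisperAdapterModel"
```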
24 changes: 24 additions & 0 deletions src/transformers/adapters/models/whisper/__init__.py
@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
    "adapter_model": [
        "WhisperAdapterModel",
        "WhisperModelWithHeads",
    ],
}


if TYPE_CHECKING:
    from .adapter_model import WhisperAdapterModel, WhisperModelWithHeads

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
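
The `_LazyModule` indirection keeps importing the package cheap: `adapter_model` is only imported when one of its exported names is first accessed. A small sketch of that behavior, assuming the package layout above:

```python
from transformers.adapters.models import whisper

# The submodule is not imported yet; attribute access triggers the real import.
model_cls = whisper.WhisperAdapterModel
```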