Upgrade Transformers to v4.36.x (#617)
calpt authored Feb 17, 2024
1 parent c83ed10 commit 26ea2c6
Showing 8 changed files with 512 additions and 34 deletions.
2 changes: 0 additions & 2 deletions Makefile
@@ -13,15 +13,13 @@ quality:
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/adapters docs --max_len 119 --check_only --path_to_docs docs
python utils/check_inits.py

# Format source code automatically and check if there are any problems left that need manual fixing

extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/adapters docs --max_len 119 --path_to_docs docs

# this target runs checks on all files and potentially modifies some of them

2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1414 files
12 changes: 4 additions & 8 deletions setup.py
@@ -21,15 +21,15 @@
# We try to follow their general layout wherever sensible.

_deps = [
"accelerate>=0.20.3",
"accelerate>=0.21.0",
"beautifulsoup4",
"black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7
"datasets!=2.5.0",
"dill<0.3.5",
"docutils==0.16.0",
"evaluate>=0.2.0",
"flake8>=3.8.3",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"isort>=5.5.4",
"Jinja2==2.11.3",
"nltk",
@@ -49,7 +49,6 @@
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.2.1",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
"sphinx-copybutton",
@@ -61,8 +60,7 @@
"sphinx-multiversion",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.35.2",
"beautifulsoup4",
"transformers~=4.36.0",
]


@@ -103,17 +101,15 @@ def deps_list(*pkgs):
"rouge-score",
"nltk",
"GitPython",
"hf-doc-builder",
"protobuf", # Can be removed once we can unpin protobuf
"sacremoses",
"rjieba",
"safetensors",
"beautifulsoup4",
"pillow",
"accelerate",
)

extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython", "hf-doc-builder")
extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython")

extras["docs"] = deps_list(
"docutils",
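As context for the tightened pin above (not part of the diff): a minimal sketch of how the new `transformers~=4.36.0` requirement can be checked in an installed environment, assuming `packaging` and `transformers` are available.

# Illustrative only: verify that the installed transformers version satisfies
# the ~=4.36.0 compatible-release pin introduced in setup.py above.
from packaging.version import Version

import transformers

installed = Version(transformers.__version__)
if not (Version("4.36.0") <= installed < Version("4.37.0")):
    raise RuntimeError(
        f"adapters expects transformers ~=4.36.0, but found {transformers.__version__}"
    )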
6 changes: 5 additions & 1 deletion src/adapters/heads/base.py
@@ -102,7 +102,11 @@ def build(self, model):
for i, module in enumerate(pred_head):
self.add_module(str(i), module)

self.apply(model._init_weights)
# We need to import the current value of _init_weights at each execution to determine if weights init is disabled.
from transformers.modeling_utils import _init_weights

if _init_weights:
self.apply(model._init_weights)
self.train(model.training) # make sure training mode is consistent

def get_output_embeddings(self):
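For context on the guard added above (illustrative, not part of the diff): transformers toggles the module-level `_init_weights` flag through its `no_init_weights` context manager, used internally by `from_pretrained`, which is why the flag must be re-imported at call time rather than captured once at module import. A minimal sketch, assuming transformers v4.36:

import transformers.modeling_utils as modeling_utils
from transformers.modeling_utils import no_init_weights

print(modeling_utils._init_weights)      # True by default
with no_init_weights():
    # Inside the context (e.g. during from_pretrained), the flag is False,
    # so the prediction head skips the redundant weight re-initialization.
    print(modeling_utils._init_weights)
print(modeling_utils._init_weights)      # restored afterwards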
240 changes: 235 additions & 5 deletions src/adapters/models/bart/modeling_bart.py
@@ -20,11 +20,15 @@
from torch import nn

from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin


logger = logging.get_logger(__name__)


class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -96,10 +100,9 @@ def forward(
bsz = query_states.size(0)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)

query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)

src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
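The switch from `.view()` to `.reshape()` above is presumably needed because adapter features such as prefix tuning and parallel composition can leave `key_states`/`value_states` non-contiguous; `.view()` then raises, while `.reshape()` falls back to a copy. A small standalone illustration with arbitrary shapes:

import torch

bsz, num_heads, seq_len, head_dim = 2, 8, 4, 16
proj_shape = (bsz * num_heads, -1, head_dim)

# A (bsz, num_heads, seq_len, head_dim) tensor that is non-contiguous, e.g. after a transpose:
key_states = torch.randn(bsz, seq_len, num_heads, head_dim).transpose(1, 2)

try:
    key_states.view(*proj_shape)          # fails on non-contiguous memory
except RuntimeError as err:
    print("view failed:", err)

flat = key_states.reshape(*proj_shape)    # copies if needed and succeeds
print(flat.shape)                         # torch.Size([16, 4, 16])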
@@ -145,22 +148,249 @@ def forward(

if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, attn_weights_reshaped, past_key_value


class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# BartFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("BartFlashAttention2 attention does not support output_attentions")

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, q_len, _ = hidden_states.size()

# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = query_states.size(0)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states get silently upcast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the
# LayerNorms to fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to the fact"
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)

attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
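The repeated `transpose(1, 2)` calls in this flash-attention variant reconcile two key/value layouts: the cache in `past_key_value` keeps `(bsz, num_heads, seq_len, head_dim)` as in the eager/SDPA paths, while the flash kernel consumes `(bsz, seq_len, num_heads, head_dim)`. A shape-only sketch of the round trip (illustrative, not from the diff):

import torch

bsz, num_heads, seq_len, head_dim = 2, 12, 7, 64

# Layout stored in past_key_value (shared with the eager/SDPA attention paths):
cached_keys = torch.randn(bsz, num_heads, seq_len, head_dim)

# Layout consumed by the flash-attention kernel:
flash_keys = cached_keys.transpose(1, 2)
assert flash_keys.shape == (bsz, seq_len, num_heads, head_dim)

# ...and transposed back when the cache is refreshed for the next decoding step:
assert flash_keys.transpose(1, 2).shape == cached_keys.shape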


class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not"
" support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
" implementation, but specifying the manual implementation will be required from Transformers version"
' v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when'
" loading the model."
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = query_states.size(0)

query_states = self._shape(query_states, tgt_len, bsz)

# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


class BartEncoderLayerWithAdapters(BartEncoderLayerAdaptersMixin, BartEncoderLayer):
def forward(
self,
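How the new attention variants are reached in practice is not shown in this diff; a hedged usage sketch, assuming the `adapters` package and a public BART checkpoint (the adapter name is illustrative):

import adapters
from transformers import BartForConditionalGeneration

# "eager", "sdpa", or "flash_attention_2" selects which of the attention variants above applies.
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", attn_implementation="sdpa"
)
adapters.init(model)                     # swaps in the adapter-enabled attention modules
model.add_adapter("my_adapter")          # illustrative adapter name, default config
model.set_active_adapters("my_adapter")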
4 changes: 4 additions & 0 deletions src/adapters/models/beit/mixin_beit.py
@@ -55,3 +55,7 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:

def set_input_embeddings(self, value):
self.embeddings.patch_embeddings = value

def post_embedding_forward(self, module, args, outputs):
embedding_output, tup = outputs
return super().post_embedding_forward(module, args, embedding_output), tup
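The unpack-and-repack above follows the BEiT patch-embedding forward now returning a tuple rather than a single tensor in transformers v4.36; `post_embedding_forward` is registered as a PyTorch forward hook, hence its `(module, args, outputs)` signature and the need to return the tuple unchanged in shape. A self-contained sketch of the hook mechanism with stand-in names (not the adapters internals):

import torch
from torch import nn


class ToyPatchEmbeddings(nn.Module):
    # Mimics a v4.36-style BEiT embedding output: (embedding_output, extra) tuple.
    def forward(self, pixel_values):
        return pixel_values.flatten(1), (14, 14)


def post_embedding_forward(module, args, outputs):
    embedding_output, tup = outputs
    # Stand-in for the adapter-side processing of the embedding output:
    return embedding_output * 1.0, tup


emb = ToyPatchEmbeddings()
emb.register_forward_hook(post_embedding_forward)   # a non-None hook return replaces the module output
out, resolution = emb(torch.zeros(1, 3, 16, 16))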