Upgrade Transformers to v4.36.x #617

Merged · 12 commits · Feb 17, 2024
2 changes: 0 additions & 2 deletions Makefile
@@ -13,15 +13,13 @@ quality:
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/adapters docs --max_len 119 --check_only --path_to_docs docs
python utils/check_inits.py

# Format source code automatically and check if there are any problems left that need manual fixing

extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/adapters docs --max_len 119 --path_to_docs docs

# this target runs checks on all files and potentially modifies some of them

2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1414 files
12 changes: 4 additions & 8 deletions setup.py
@@ -21,15 +21,15 @@
# We try to follow their general layout wherever sensible.

_deps = [
"accelerate>=0.20.3",
"accelerate>=0.21.0",
"beautifulsoup4",
"black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7
"datasets!=2.5.0",
"dill<0.3.5",
"docutils==0.16.0",
"evaluate>=0.2.0",
"flake8>=3.8.3",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"isort>=5.5.4",
"Jinja2==2.11.3",
"nltk",
@@ -49,7 +49,6 @@
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.2.1",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
"sphinx-copybutton",
@@ -61,8 +60,7 @@
"sphinx-multiversion",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.35.2",
"beautifulsoup4",
"transformers~=4.36.0",
]


@@ -103,17 +101,15 @@ def deps_list(*pkgs):
"rouge-score",
"nltk",
"GitPython",
"hf-doc-builder",
"protobuf", # Can be removed once we can unpin protobuf
"sacremoses",
"rjieba",
"safetensors",
"beautifulsoup4",
"pillow",
"accelerate",
)

extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython", "hf-doc-builder")
extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython")

extras["docs"] = deps_list(
"docutils",
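Aside (editor's note): the `_deps` entries above are flat requirement strings, and the `deps_list` helper resolves bare package names against them when assembling the `extras` groups such as `extras["quality"]`. A minimal sketch of that pattern, assuming a lookup table keyed by package name (the exact regex in the real setup.py may differ):

import re

_deps = [
    "accelerate>=0.21.0",
    "torch>=1.10,!=1.12.0",
    "transformers~=4.36.0",
]

# map each bare package name to its full requirement string
deps = {re.match(r"^[^!=<>~ ]+", d).group(0): d for d in _deps}

def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]

print(deps_list("transformers", "accelerate"))
# ['transformers~=4.36.0', 'accelerate>=0.21.0']
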
6 changes: 5 additions & 1 deletion src/adapters/heads/base.py
@@ -102,7 +102,11 @@ def build(self, model):
for i, module in enumerate(pred_head):
self.add_module(str(i), module)

self.apply(model._init_weights)
        # We need to import the current value of _init_weights at each execution to determine whether weight initialization is disabled.
from transformers.modeling_utils import _init_weights

if _init_weights:
self.apply(model._init_weights)
self.train(model.training) # make sure training mode is consistent

def get_output_embeddings(self):
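Aside (editor's note): the runtime import in `build()` matters because `_init_weights` is a module-level flag in `transformers.modeling_utils` that the `no_init_weights` context manager toggles off temporarily, e.g. during fast initialization in `from_pretrained`. A minimal sketch of that behaviour, assuming transformers v4.36:

import transformers.modeling_utils as modeling_utils

print(modeling_utils._init_weights)      # True by default

with modeling_utils.no_init_weights():
    # while this context is active, prediction head weights must not be re-initialized
    print(modeling_utils._init_weights)  # False

print(modeling_utils._init_weights)      # restored to True

A `from transformers.modeling_utils import _init_weights` executed once at module import time would bind a stale value and miss this toggle, which is why the import happens inside `build()`.
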
240 changes: 235 additions & 5 deletions src/adapters/models/bart/modeling_bart.py
@@ -20,11 +20,15 @@
from torch import nn

from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin


logger = logging.get_logger(__name__)


class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -96,10 +100,9 @@ def forward(
bsz = query_states.size(0)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)

query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)

src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
@@ -145,22 +148,249 @@

if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, attn_weights_reshaped, past_key_value


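Aside (editor's note): `.reshape()` behaves like `.view()` on contiguous tensors but falls back to a copy when the input is not contiguous, so it is the safer choice after the prefix-tuning and parallel-composition adjustments above. A minimal sketch of the difference:

import torch

x = torch.randn(2, 4, 3, 8).transpose(1, 2)  # non-contiguous after the transpose
print(x.reshape(2 * 3, 4, 8).shape)          # works: torch.Size([6, 4, 8])
x.view(2 * 3, 4, 8)                          # raises RuntimeError (incompatible size/stride)
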
class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# BartFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("BartFlashAttention2 attention does not support output_attentions")

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, q_len, _ = hidden_states.size()

# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = query_states.size(0)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

        # In PEFT, we usually cast the layer norms to float32 for training stability,
        # so the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the
        # LayerNorms to fp32. (LlamaRMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to the fact"
" you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
)

attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.out_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


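Aside (editor's note): the `transpose(1, 2)` calls in the flash-attention path convert between the `(bsz, num_heads, seq_len, head_dim)` layout used for the key/value cache and the `(bsz, seq_len, num_heads, head_dim)` layout consumed by the flash kernel, with new keys concatenated along the sequence axis. A minimal shape sketch, assuming those layouts:

import torch

bsz, seq_len, num_heads, head_dim = 2, 7, 4, 8

k_flash = torch.randn(bsz, seq_len, num_heads, head_dim)  # flash-attention layout
k_cache = k_flash.transpose(1, 2)                         # cache layout: (bsz, heads, seq, head_dim)

k_new = torch.randn(bsz, 3, num_heads, head_dim)          # keys for 3 new positions
k_flash = torch.cat([k_cache.transpose(1, 2), k_new], dim=1)
print(k_flash.shape)                                      # torch.Size([2, 10, 4, 8])
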
class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not"
" support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
" implementation, but specifying the manual implementation will be required from Transformers version"
' v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when'
" loading the model."
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
bsz = query_states.size(0)

query_states = self._shape(query_states, tgt_len, bsz)

# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value


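Aside (editor's note): a minimal standalone call to `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.x) mirroring the guard above, where causality is requested via `is_causal` only when no explicit mask is given and the target length is greater than one; shapes are illustrative:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 4, 6, 8
q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, tgt_len, head_dim)
v = torch.randn(bsz, num_heads, tgt_len, head_dim)

attention_mask = None  # a 4D additive mask would be passed here instead when available
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=attention_mask is None and tgt_len > 1,
)
print(out.shape)  # torch.Size([2, 4, 6, 8])
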
class BartEncoderLayerWithAdapters(BartEncoderLayerAdaptersMixin, BartEncoderLayer):
def forward(
self,
4 changes: 4 additions & 0 deletions src/adapters/models/beit/mixin_beit.py
@@ -55,3 +55,7 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:

def set_input_embeddings(self, value):
self.embeddings.patch_embeddings = value

def post_embedding_forward(self, module, args, outputs):
embedding_output, tup = outputs
return super().post_embedding_forward(module, args, embedding_output), tup
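
Aside (editor's note): `post_embedding_forward` is used as a forward hook on the embeddings module; since transformers v4.36 the BEiT embeddings appear to return an `(embedding_output, (patch_height, patch_width))` tuple instead of a bare tensor, so the hook unpacks the tuple, processes only the tensor, and repacks it. A minimal sketch of the hook mechanics with a hypothetical stand-in module:

import torch
from torch import nn

class TupleEmbeddings(nn.Module):
    # stand-in for an embeddings module returning (tensor, extra_info)
    def forward(self, x):
        return x * 2, (14, 14)

def post_embedding_forward(module, args, outputs):
    embedding_output, tup = outputs
    # transform only the tensor part (the real mixin delegates to super() here)
    return embedding_output + 1, tup

emb = TupleEmbeddings()
emb.register_forward_hook(post_embedding_forward)  # a non-None return value replaces the output
out, hw = emb(torch.ones(1, 3))
print(out, hw)  # tensor([[3., 3., 3.]]) (14, 14)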