From ce847b24be0db3816a09eb6654a5539dee51f428 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 13 Dec 2023 21:13:02 +0100 Subject: [PATCH 1/9] Upgrade Transformers to v4.36.0 --- Makefile | 2 -- hf_transformers | 2 +- setup.py | 12 ++++-------- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index a4333c663..04b937413 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,6 @@ quality: python utils/custom_init_isort.py --check_only python utils/sort_auto_mappings.py --check_only flake8 $(check_dirs) - doc-builder style src/adapters docs --max_len 119 --check_only --path_to_docs docs python utils/check_inits.py # Format source code automatically and check is there are any problems left that need manual fixing @@ -21,7 +20,6 @@ quality: extra_style_checks: python utils/custom_init_isort.py python utils/sort_auto_mappings.py - doc-builder style src/adapters docs --max_len 119 --path_to_docs docs # this target runs checks on all files and potentially modifies some of them diff --git a/hf_transformers b/hf_transformers index 514de24ab..14666775a 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 514de24abfd4416aeba6a6455ad5920f57f3567d +Subproject commit 14666775a296a76c88e1aa686a9547f393d322e2 diff --git a/setup.py b/setup.py index b4c89ccd6..4c7e62b2a 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ # We try to follow their general layout wherever sensible. _deps = [ - "accelerate>=0.20.3", + "accelerate>=0.21.0", + "beautifulsoup4", "black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7 "datasets!=2.5.0", "dill<0.3.5", @@ -29,7 +30,6 @@ "evaluate>=0.2.0", "flake8>=3.8.3", "GitPython<3.1.19", - "hf-doc-builder>=0.3.0", "isort>=5.5.4", "Jinja2==2.11.3", "nltk", @@ -49,7 +49,6 @@ "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "sacrebleu>=1.4.12,<2.0.0", "sacremoses", - "safetensors>=0.2.1", "scikit-learn", "sentencepiece>=0.1.91,!=0.1.92", "sphinx-copybutton", @@ -61,8 +60,7 @@ "sphinx-multiversion", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers==4.35.2", - "beautifulsoup4", + "transformers==4.36.0", ] @@ -103,17 +101,15 @@ def deps_list(*pkgs): "rouge-score", "nltk", "GitPython", - "hf-doc-builder", "protobuf", # Can be removed once we can unpin protobuf "sacremoses", "rjieba", - "safetensors", "beautifulsoup4", "pillow", "accelerate", ) -extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython", "hf-doc-builder") +extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython") extras["docs"] = deps_list( "docutils", From 851dda64579236925d77c85cb99692fffe229f0a Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 16 Dec 2023 17:38:31 +0100 Subject: [PATCH 2/9] Upgrade to v4.36.1 --- hf_transformers | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hf_transformers b/hf_transformers index 14666775a..c48787f34 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 14666775a296a76c88e1aa686a9547f393d322e2 +Subproject commit c48787f347bd604f656c2cfff730e029c8f8c1fe diff --git a/setup.py b/setup.py index 4c7e62b2a..17bfcb58c 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ "sphinx-multiversion", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers==4.36.0", + "transformers==4.36.1", ] From d55b79bcead63b52953d7735ccf09eb79994390d Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 16 Dec 2023 18:13:30 +0100 Subject: [PATCH 3/9] Fix Llama & Beit for Transformers v4.36 --- src/adapters/models/beit/mixin_beit.py | 4 ++ src/adapters/models/llama/modeling_llama.py | 70 ++++++++++++++++----- tests/test_llama.py | 2 +- 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index 608d9c5cc..d4490a33d 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -55,3 +55,7 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: def set_input_embeddings(self, value): self.embeddings.patch_embeddings = value + + def post_embedding_forward(self, module, args, outputs): + embedding_output, tup = outputs + return super().post_embedding_forward(module, args, embedding_output), tup diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 3b22e5ae1..187a87439 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -19,9 +19,11 @@ # limitations under the License. """ PyTorch LLaMA model.""" import math +import warnings from typing import Optional, Tuple import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -30,7 +32,8 @@ adjust_tensors_for_parallel_, match_attn_matrices_for_parallel, ) -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin @@ -47,15 +50,43 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states @@ -64,17 +95,22 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = (key_states, value_states) if use_cache else None + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask @@ -99,10 +135,10 @@ def forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -111,10 +147,16 @@ def forward( f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None diff --git a/tests/test_llama.py b/tests/test_llama.py index 2fd455c17..9cb6fcfda 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -56,7 +56,7 @@ class LlamaAdapterTest( @require_torch -class BertClassConversionTest( +class LlamaClassConversionTest( ModelClassConversionTestMixin, LlamaAdapterTestBase, unittest.TestCase, From 6e79c7d7cc7bdea0dc1a1bbe4273e5cd8f236a07 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 16 Dec 2023 18:16:42 +0100 Subject: [PATCH 4/9] style --- src/adapters/models/llama/modeling_llama.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 187a87439..990a4607b 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -57,7 +57,8 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" + " `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() @@ -97,9 +98,9 @@ def forward( if past_key_value is not None: if self.layer_idx is None: raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + "The cache structure has changed since version v4.36. If you are using" + f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" + " initialize the attention class with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) From 767e868c19efa808cc3aa21c0150146808332691 Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 18 Dec 2023 22:31:05 +0100 Subject: [PATCH 5/9] Upgrade to Transformers v4.36.2 --- hf_transformers | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hf_transformers b/hf_transformers index c48787f34..a7cab3c28 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit c48787f347bd604f656c2cfff730e029c8f8c1fe +Subproject commit a7cab3c283312b8d4de5df3bbe719971e24f4281 diff --git a/setup.py b/setup.py index 17bfcb58c..0cf38c07c 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ "sphinx-multiversion", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers==4.36.1", + "transformers~=4.36.0", ] From a96a49f615e566e934584b7042b63ca695cdc3c3 Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 5 Jan 2024 23:32:03 +0100 Subject: [PATCH 6/9] Llama adapter modules base classes --- src/adapters/models/llama/modeling_llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 990a4607b..c1fc25c7b 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -33,7 +33,7 @@ match_attn_matrices_for_parallel, ) from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin @@ -42,7 +42,7 @@ logger = logging.get_logger(__name__) -class LlamaAttentionWithAdapters(nn.Module, LlamaAttentionMixin): +class LlamaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def forward( @@ -165,7 +165,7 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayerWithAdapters(nn.Module, LlamaDecoderLayerMixin): +class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, From aea8a2dad5df22689cb3b13fd7ca42aa7bd7dc5c Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 21 Jan 2024 18:43:14 +0100 Subject: [PATCH 7/9] Support fast attention implementations for Bart & Llama --- src/adapters/models/bart/modeling_bart.py | 238 +++++++++++++++++++- src/adapters/models/llama/modeling_llama.py | 212 ++++++++++++++++- 2 files changed, 439 insertions(+), 11 deletions(-) diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index 28bf37bd7..1b10d3fb6 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -19,12 +19,16 @@ import torch.utils.checkpoint from torch import nn +from transformers.utils import logging from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin +logger = logging.get_logger(__name__) + + class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -96,10 +100,9 @@ def forward( bsz = query_states.size(0) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) @@ -145,7 +148,7 @@ def forward( if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) @@ -153,7 +156,7 @@ def forward( attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. + # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) @@ -161,6 +164,231 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention): + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("BartFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention): + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + if output_attentions or layer_head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + key_value_states=key_value_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + bsz = query_states.size(0) + + query_states = self._shape(query_states, tgt_len, bsz) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + is_causal=self.is_causal and attention_mask is None and tgt_len > 1, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + class BartEncoderLayerWithAdapters(BartEncoderLayerAdaptersMixin, BartEncoderLayer): def forward( self, diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index c1fc25c7b..ab7acdb35 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -57,8 +57,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() @@ -98,9 +97,9 @@ def forward( if past_key_value is not None: if self.layer_idx is None: raise ValueError( - "The cache structure has changed since version v4.36. If you are using" - f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" - " initialize the attention class with a layer index." + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -122,7 +121,7 @@ def forward( # Make adjustments since (parallel) prefix tuning changes the attention mask kv_seq_len = key_states.shape[-2] - bsz = attention_mask.shape[0] + bsz = key_states.shape[0] if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -165,6 +164,207 @@ def forward( return attn_output, attn_weights, past_key_value +class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + + # Make adjustments since (parallel) prefix tuning changes the attention mask + kv_seq_len = key_states.shape[-2] + bsz = key_states.shape[0] + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + + # Make adjustments since (parallel) prefix tuning changes the attention mask + kv_seq_len = key_states.shape[-2] + bsz = key_states.shape[0] + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer): def forward( self, From 269a15250f987e97a8672c4ff6a159b545c96450 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 21 Jan 2024 18:46:08 +0100 Subject: [PATCH 8/9] style --- src/adapters/models/bart/modeling_bart.py | 14 +++++++------ src/adapters/models/llama/modeling_llama.py | 23 ++++++++++++--------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index 1b10d3fb6..b347fddf0 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -19,8 +19,8 @@ import torch.utils.checkpoint from torch import nn -from transformers.utils import logging from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer +from transformers.utils import logging from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -165,7 +165,6 @@ def forward( class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention): - def forward( self, hidden_states: torch.Tensor, @@ -254,8 +253,8 @@ def forward( target_dtype = self.q_proj.weight.dtype logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) @@ -290,8 +289,11 @@ def forward( if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not" + " support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" + " implementation, but specifying the manual implementation will be required from Transformers version" + ' v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when' + " loading the model." ) return super().forward( hidden_states, diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index ab7acdb35..baefac75a 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -57,7 +57,8 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" + " `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() @@ -97,9 +98,9 @@ def forward( if past_key_value is not None: if self.layer_idx is None: raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + "The cache structure has changed since version v4.36. If you are using" + f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" + " initialize the attention class with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -165,7 +166,6 @@ def forward( class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaAttention): - def forward( self, hidden_states: torch.Tensor, @@ -179,7 +179,8 @@ def forward( # LlamaFlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" + " `attention_mask` instead.`" ) # overwrite attention_mask with padding_mask @@ -247,8 +248,8 @@ def forward( target_dtype = self.q_proj.weight.dtype logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) @@ -284,8 +285,10 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" + " not support `output_attentions=True`. Falling back to the manual attention implementation, but" + " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" + ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, From de8e8af8f4486628a1544a47e0297b5848055098 Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 12 Feb 2024 22:24:55 +0100 Subject: [PATCH 9/9] Fix _init_weights during adapter head loading. --- src/adapters/heads/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/adapters/heads/base.py b/src/adapters/heads/base.py index 343b2774a..0f35752b0 100644 --- a/src/adapters/heads/base.py +++ b/src/adapters/heads/base.py @@ -102,7 +102,11 @@ def build(self, model): for i, module in enumerate(pred_head): self.add_module(str(i), module) - self.apply(model._init_weights) + # We need to import the current value of _init_weights at each execution to determine if weights init is disabled. + from transformers.modeling_utils import _init_weights + + if _init_weights: + self.apply(model._init_weights) self.train(model.training) # make sure training mode is consistent def get_output_embeddings(self):