From 29bcbed3016cbb63bce557b29fabf03840b53a59 Mon Sep 17 00:00:00 2001
From: Varun Gumma
Date: Sat, 24 Aug 2024 14:19:56 +0000
Subject: [PATCH] bug fix

---
 .../modules/fast_grouped_query_attention.py | 15 +--
 fairseq/modules/transformer_layer.py        | 91 ++++++++++---------
 2 files changed, 50 insertions(+), 56 deletions(-)

diff --git a/fairseq/modules/fast_grouped_query_attention.py b/fairseq/modules/fast_grouped_query_attention.py
index 95f0670784..26f25a7bd2 100644
--- a/fairseq/modules/fast_grouped_query_attention.py
+++ b/fairseq/modules/fast_grouped_query_attention.py
@@ -247,17 +247,10 @@ def forward(
 
         if (self.num_heads != self.num_kv_heads) and k is not None and v is not None:
             # self.num_heads == self.num_kv_heads * self.q_per_kv
-            k = rearrange(k, "(b h) t d -> b h 1 t d", h=self.num_kv_heads)
-            k = k.expand(bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim)
-            k = rearrange(
-                k, "b h nq t d -> (b h nq) t d", h=self.num_kv_heads, nq=self.q_per_kv
-            )
-
-            v = rearrange(v, "(b h) t d -> b h 1 t d", h=self.num_kv_heads)
-            v = v.expand(bsz, self.num_kv_heads, self.q_per_kv, -1, self.head_dim)
-            v = rearrange(
-                v, "b h nq t d -> (b h nq) t d", h=self.num_kv_heads, nq=self.q_per_kv
-            )
+            k = rearrange(k, "(b h) t d -> b h t d", h=self.num_kv_heads)
+            k = torch.repeat_interleave(k, dim=1, repeats=self.q_per_kv)
+            v = rearrange(v, "(b h) t d -> b h t d", h=self.num_kv_heads)
+            v = torch.repeat_interleave(v, dim=1, repeats=self.q_per_kv)
 
         if saved_state is not None:
             # saved states are stored with shape (bsz, num_kv_heads, q_per_kv, seq_len, head_dim)
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index d6bc249e8f..7520e2f415 100755
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -150,42 +150,43 @@ def _prune_fc_layer(self, remove_index: List[int]):
         self.fc2.bias = torch.nn.Parameter(new_fc2_bias)
 
     def build_self_attention(self, embed_dim, cfg):
-        if self.attn_implementation == "fairseq":
-            return MultiheadAttention(
+        is_fused = self.attn_implementation.endswith("fused")
+
+        if self.attn_implementation.startswith("fast"):
+            return FastMultiheadAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.encoder.xformers_att_config,
+                rope_args=getattr(cfg, "rope_args", None),
             )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
-            return FastMultiheadAttention(
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.encoder, "kv_attention_heads", None) is not None
+        ):
+            return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
+                cfg.encoder.kv_attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_fused")
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
-            return FastGroupedQueryAttention(
+        else:
+            return MultiheadAttention(
                 embed_dim,
                 cfg.encoder.attention_heads,
-                cfg.encoder.kv_attention_heads,
                 dropout=cfg.attention_dropout,
                 self_attention=True,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_gqa_fused")
-            )
-        else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
+                xformers_att_config=cfg.encoder.xformers_att_config,
             )
 
     def build_normalization(self, dim, rms=False):
@@ -401,19 +402,9 @@ def build_glu(self, input_dim, intermediate_dim, activation_fn="silu", bias=Fals
     def build_self_attention(
         self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
     ):
-        if self.attn_implementation == "fairseq":
-            return MultiheadAttention(
-                embed_dim,
-                cfg.decoder.attention_heads,
-                dropout=cfg.attention_dropout,
-                add_bias_kv=add_bias_kv,
-                add_zero_attn=add_zero_attn,
-                self_attention=True,
-                q_noise=self.quant_noise,
-                qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.decoder.xformers_att_config,
-            )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
+        is_fused = self.attn_implementation.endswith("fused")
+
+        if self.attn_implementation.startswith("fast"):
             return FastMultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -422,12 +413,15 @@ def build_self_attention(
                 add_zero_attn=add_zero_attn,
                 self_attention=True,
                 is_decoder=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_fused")
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.decoder, "kv_attention_heads", None) is not None
+        ):
             return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -437,30 +431,26 @@ def build_self_attention(
                 add_zero_attn=add_zero_attn,
                 self_attention=True,
                 is_decoder=True,
+                fused_qkv=is_fused,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
                 rope_args=getattr(cfg, "rope_args", None),
-                fused_qkv=(self.attn_implementation == "fast_gqa_fused")
             )
         else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
-            )
-
-    def build_encoder_attention(self, embed_dim, cfg):
-        if self.attn_implementation == "fairseq":
             return MultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
-                kdim=cfg.encoder.embed_dim,
-                vdim=cfg.encoder.embed_dim,
                 dropout=cfg.attention_dropout,
-                encoder_decoder_attention=True,
+                add_bias_kv=add_bias_kv,
+                add_zero_attn=add_zero_attn,
+                self_attention=True,
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
-                xformers_att_config=cfg.encoder.xformers_att_config,
+                xformers_att_config=cfg.decoder.xformers_att_config,
             )
-        elif self.attn_implementation == "fast" or self.attn_implementation == "fast_fused":
+
+    def build_encoder_attention(self, embed_dim, cfg):
+        if self.attn_implementation.startswith("fast"):
             return FastMultiheadAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -473,7 +463,10 @@ def build_encoder_attention(self, embed_dim, cfg):
                 q_noise=self.quant_noise,
                 qn_block_size=self.quant_noise_block_size,
             )
-        elif self.attn_implementation == "fast_gqa" or self.attn_implementation == "fast_gqa_fused":
+        elif (
+            self.attn_implementation.startswith("fast_gqa")
+            and getattr(cfg.decoder, "kv_attention_heads", None) is not None
+        ):
             return FastGroupedQueryAttention(
                 embed_dim,
                 cfg.decoder.attention_heads,
@@ -488,8 +481,16 @@ def build_encoder_attention(self, embed_dim, cfg):
                 qn_block_size=self.quant_noise_block_size,
             )
         else:
-            raise NotImplementedError(
-                f"Unknown attention implementation: {self.attn_implementation}"
+            return MultiheadAttention(
+                embed_dim,
+                cfg.decoder.attention_heads,
+                kdim=cfg.encoder.embed_dim,
+                vdim=cfg.encoder.embed_dim,
+                dropout=cfg.attention_dropout,
+                encoder_decoder_attention=True,
+                q_noise=self.quant_noise,
+                qn_block_size=self.quant_noise_block_size,
+                xformers_att_config=cfg.encoder.xformers_att_config,
            )
 
     def prepare_for_onnx_export_(self):
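
The core fix in fast_grouped_query_attention.py replaces the rearrange/expand/rearrange sequence with torch.repeat_interleave when duplicating KV heads for grouped-query attention. The standalone sketch below is only illustrative and not part of the patch: the sizes (bsz, num_kv_heads, q_per_kv, tgt_len, head_dim) are made up, and only the reshaping logic mirrors the hunk. It checks that both expansions place the q_per_kv copies of each KV head in the same order.

    # Illustrative equivalence check of the two KV-head expansions; sizes are arbitrary.
    import torch
    from einops import rearrange

    bsz, num_kv_heads, q_per_kv, tgt_len, head_dim = 2, 4, 3, 5, 8
    k = torch.randn(bsz * num_kv_heads, tgt_len, head_dim)  # flattened (b h) t d layout

    # Old path: insert a singleton group axis, expand it, then re-flatten.
    k_old = rearrange(k, "(b h) t d -> b h 1 t d", h=num_kv_heads)
    k_old = k_old.expand(bsz, num_kv_heads, q_per_kv, tgt_len, head_dim)
    k_old = rearrange(k_old, "b h nq t d -> (b h nq) t d")

    # New path: unflatten the KV heads and repeat each one q_per_kv times along the head axis.
    k_new = rearrange(k, "(b h) t d -> b h t d", h=num_kv_heads)
    k_new = torch.repeat_interleave(k_new, repeats=q_per_kv, dim=1)

    # The patched code keeps k/v as (bsz, num_heads, tgt_len, head_dim);
    # flatten here only to compare against the old layout.
    assert torch.equal(rearrange(k_new, "b h t d -> (b h) t d"), k_old)

Both orderings keep the copies of each KV head contiguous, so the query heads line up with the same KV heads as before; repeat_interleave simply avoids the five-dimensional intermediate and the second rearrange. v is handled identically to k in the patch.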