From 1f07200a2659c27f159cf8dda730096a77f9ccce Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Sat, 6 Apr 2024 14:08:33 +0200
Subject: [PATCH] FIX: Add safe guards for static cache + llama on transformers latest (#401)

---
 awq/modules/fused/attn.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index f1732ea5..9775126b 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -188,16 +188,19 @@ def forward(
             # Always reset to 0
             self.start_pos = 0
 
+        hf_is_generating = False
+
+        if self.is_hf_transformers and "use_cache" in kwargs:
+            hf_is_generating = kwargs["use_cache"]
+
+
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (
-            self.is_hf_transformers
-            and "past_key_value" in kwargs
-            and kwargs["past_key_value"] is None
-        ):
+        if (self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0
+
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
 
@@ -214,8 +217,6 @@ def forward(
         if not self.use_alibi:
             xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
-        self.cache.to(xq)
-
         values_store = xv.transpose(2, 1)
         keys_store = (
             xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
@@ -223,6 +224,7 @@ def forward(
             .contiguous()
         )
 
+        self.cache.to(xq)
         self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
 
         # Only necessary to retrieve from cache when we are not processing context
@@ -248,6 +250,11 @@ def forward(
 
             # When seqlen is 1, there is nothing else to attend to
             if attention_mask is not None and seqlen > 1:
+                # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
+                # need to slice it
+                if attention_mask.shape[-1] != seqlen:
+                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]
+
                 scores = (
                     scores + attention_mask
                 )  # (bs, n_local_heads, slen, cache_len + slen)
@@ -278,11 +285,15 @@ def forward(
 
         attn_output = self.o_proj(attention_weight)
         self.start_pos += seqlen
 
+        if self.is_hf_transformers and not hf_is_generating:
+            self.start_pos = 0
+
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
         past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
+
         if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
             new_cache = DynamicCache()
             new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)