From 95b7043d592100702b04464529e1777ddbbab5b9 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 14 Jan 2025 13:54:17 +0000
Subject: [PATCH] refactor attention

Signed-off-by: jiqing-feng
---
 optimum/exporters/ipex/modeling_utils.py | 54 ++++++++++++------------
 1 file changed, 27 insertions(+), 27 deletions(-)

diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 1bf605ca7..46f3868cc 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -634,7 +634,9 @@ def has_flash_attn(self, query):
         elif query.device.type == "xpu":
             return is_torch_version(">", "2.5.99")

-    def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
+    def attention_interface(
+        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+    ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
             attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -650,15 +652,15 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
-                query,
-                key_cache,
-                value_cache,
+                query.contiguous() if query.device.type == "xpu" else query,
+                key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
+                value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
@@ -668,7 +670,7 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value
                 past_key_value.block_tables,
                 None,
             )
-        else:
+        elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
@@ -688,6 +690,22 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value
                 False,
                 None,
             )
+        else:
+            # decode
+            attn_output = torch.empty_like(query)
+            PagedAttention.single_query_cached_kv_attention(
+                attn_output,
+                query,
+                key_cache,
+                value_cache,
+                self.kv_head_mapping,
+                1.0 / math.sqrt(self.head_dim),
+                past_key_value.block_tables,
+                input_lens,
+                past_key_value.block_size,
+                input_lens.max(),
+                None,
+            )

         return attn_output

@@ -714,27 +732,9 @@ def forward(
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)

-        if past_len == 0:
-            # prefill
-            attn_output = self.prefill_attn(
-                query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
-            )
-        else:
-            # decode
-            attn_output = torch.empty_like(query)
-            PagedAttention.single_query_cached_kv_attention(
-                attn_output,
-                query,
-                key_cache,
-                value_cache,
-                self.kv_head_mapping,
-                1.0 / math.sqrt(self.head_dim),
-                past_key_value.block_tables,
-                input_lens,
-                past_key_value.block_size,
-                input_lens.max(),
-                None,
-            )
+        attn_output = self.attention_interface(
+            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        )

         attn_output = self.postprocess_attention_output(attn_output)
         if not output_attentions:
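
For context, the sketch below is not part of the patch and is not the optimum/IPEX implementation; it only illustrates the dispatch pattern the refactored attention_interface follows, where a single method selects between the SDPA path, the flash-attention prefill path, the fallback prefill path, and the decode path based on past_len. The helper name attention_interface_sketch, the has_flash_attn flag argument, and the plain-PyTorch SDPA stand-in are assumptions for illustration; the real code calls the IPEX PagedAttention kernels shown in the diff.

import math
import torch


def attention_interface_sketch(query, key, value, past_len, has_flash_attn):
    # Illustrative only: mirror the patched control flow, choosing a path from
    # past_len (0 = prefill, >0 = decode) and flash-attention availability.
    if past_len == 0 and has_flash_attn:
        # Prefill with flash attention; the patch calls
        # PagedAttention.flash_attn_varlen_func on the paged KV cache here.
        path = "flash_varlen_prefill"
    elif past_len == 0:
        # Prefill fallback without flash attention (varlen attention in the patch).
        path = "varlen_prefill"
    else:
        # Decode of one new token per sequence; the patch calls
        # PagedAttention.single_query_cached_kv_attention against the KV cache.
        path = "single_query_decode"

    # Stand-in math so the sketch runs: plain SDPA over the full key/value.
    scale = 1.0 / math.sqrt(query.shape[-1])
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, scale=scale, is_causal=(past_len == 0)
    )
    return attn_output, path


# Example usage with hypothetical shapes (batch, heads, seq, head_dim):
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out, path = attention_interface_sketch(q, k, v, past_len=0, has_flash_attn=False)
print(out.shape, path)  # torch.Size([1, 8, 16, 64]) varlen_prefill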