
Commit 95b7043

refactor attention
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Jan 14, 2025
1 parent daddabf · commit 95b7043
Showing 1 changed file with 27 additions and 27 deletions.
optimum/exporters/ipex/modeling_utils.py: 54 changes (27 additions & 27 deletions)
@@ -634,7 +634,9 @@ def has_flash_attn(self, query):
         elif query.device.type == "xpu":
             return is_torch_version(">", "2.5.99")
 
-    def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
+    def attention_interface(
+        self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+    ):
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
             attn_output = torch.nn.functional.scaled_dot_product_attention(
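
The hunk above folds what used to be two call sites into a single entry point keyed on past_len. As a minimal standalone sketch of that dispatch shape (hypothetical stand-in functions, not the optimum-intel implementation):

    import torch

    def run_prefill(q: torch.Tensor) -> torch.Tensor:
        # stand-in for the padding-free / flash prefill paths below
        return q

    def run_decode(q: torch.Tensor) -> torch.Tensor:
        # stand-in for PagedAttention.single_query_cached_kv_attention
        return q

    def attention_interface(q: torch.Tensor, past_len: int) -> torch.Tensor:
        # past_len == 0 means nothing is cached yet, i.e. the prefill phase;
        # every later step feeds one new token per sequence (decode).
        return run_prefill(q) if past_len == 0 else run_decode(q)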
@@ -650,15 +652,15 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
-                query,
-                key_cache,
-                value_cache,
+                query.contiguous() if query.device.type == "xpu" else query,
+                key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
+                value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
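
The three new arguments above force contiguous layouts on XPU before calling the varlen kernel; that the XPU kernel wants contiguous inputs is inferred from this diff, not from kernel documentation. The guard pattern in isolation:

    import torch

    def contiguous_on_xpu(t: torch.Tensor) -> torch.Tensor:
        # Tensor.contiguous() returns the tensor itself when it is already
        # contiguous, so this copies only when a copy is actually needed.
        return t.contiguous() if t.device.type == "xpu" else t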
@@ -668,7 +670,7 @@
                 past_key_value.block_tables,
                 None,
             )
-        else:
+        elif past_len == 0:
             # prefill, remove padding
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
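
Both prefill branches build seq_len_tensor the same way: a leading zero followed by the running sum of per-sequence lengths, which gives the start offset of each sequence in the packed (padding-free) batch. A small worked example:

    import torch

    input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # tokens per sequence
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    print(seq_len_tensor)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
    # Sequence i occupies rows seq_len_tensor[i] : seq_len_tensor[i + 1]
    # of the packed query/key/value tensors, as varlen attention expects.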
@@ -688,6 +690,22 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value
                 False,
                 None,
             )
+        else:
+            # decode
+            attn_output = torch.empty_like(query)
+            PagedAttention.single_query_cached_kv_attention(
+                attn_output,
+                query,
+                key_cache,
+                value_cache,
+                self.kv_head_mapping,
+                1.0 / math.sqrt(self.head_dim),
+                past_key_value.block_tables,
+                input_lens,
+                past_key_value.block_size,
+                input_lens.max(),
+                None,
+            )
 
         return attn_output
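
The decode branch added above reads from a paged KV cache: each sequence's cached tokens live in fixed-size blocks found through past_key_value.block_tables. An illustrative sketch of the position-to-slot arithmetic such a kernel performs (assumed here for exposition; not the PagedAttention API):

    import torch

    block_size = 16                        # tokens per cache block (illustrative)
    block_table = torch.tensor([4, 9, 1])  # physical block ids for one sequence

    def kv_slot(pos: int) -> tuple[int, int]:
        # A token at logical position pos lives in physical block
        # block_table[pos // block_size], at offset pos % block_size.
        return int(block_table[pos // block_size]), pos % block_size

    print(kv_slot(0))   # (4, 0): first token, start of the first block
    print(kv_slot(20))  # (9, 4): position 20 falls in the second block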

@@ -714,27 +732,9 @@ def forward(
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens)
 
-        if past_len == 0:
-            # prefill
-            attn_output = self.prefill_attn(
-                query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
-            )
-        else:
-            # decode
-            attn_output = torch.empty_like(query)
-            PagedAttention.single_query_cached_kv_attention(
-                attn_output,
-                query,
-                key_cache,
-                value_cache,
-                self.kv_head_mapping,
-                1.0 / math.sqrt(self.head_dim),
-                past_key_value.block_tables,
-                input_lens,
-                past_key_value.block_size,
-                input_lens.max(),
-                None,
-            )
+        attn_output = self.attention_interface(
+            query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len
+        )
 
         attn_output = self.postprocess_attention_output(attn_output)
         if not output_attentions:
