
Commit

add apply rotary back
faaany committed Jun 6, 2024
1 parent ae51e3b commit e17c29d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,7 +19,7 @@
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version
from optimum.intel.utils.modeling_utils import _setattr_from_module
@@ -200,7 +200,7 @@ def rope(self, query, key, kv_seq_len, position_ids, use_cache):
)
return query, key

def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask):
def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
# This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
# to decide which beam to attend to, making the scaled dot-product attention more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
@@ -215,11 +215,14 @@ def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask):
return attn_output, past_key_value, attn_weights

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask):
def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

past_key_value = None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -274,7 +277,7 @@ def forward(
kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len

query, key, value = self.qkv_gemm(hidden_states)
query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache)

sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask)
attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask, position_ids)
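For context on the step this commit re-adds: below is a minimal, standalone sketch (not the repository's code) of what the new cos, sin = self.rotary_emb(...) and apply_rotary_pos_emb(...) lines do to the query/key states. The rotary_tables helper is an illustrative stand-in for the attention module's rotary_emb, and the shapes and base follow the usual Llama convention in transformers; only the names torch and the tensor shapes shown are assumed.

import torch


def rotate_half(x):
    # Swap the two halves of the head dimension, flipping the sign of the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # Rotate query/key states; cos/sin have shape (batch, seq_len, head_dim) and are
    # broadcast over the head axis of (batch, num_heads, seq_len, head_dim) tensors.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def rotary_tables(position_ids, head_dim, base=10000.0):
    # Illustrative stand-in for self.rotary_emb: cos/sin tables for the given positions,
    # with the frequencies duplicated across both halves of the head dimension.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids.float()[..., None] * inv_freq  # (batch, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


if __name__ == "__main__":
    batch, heads, seq_len, head_dim = 1, 4, 6, 8
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rotary_tables(position_ids, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 4, 6, 8])

In the diff itself, the cos/sin tables come from the module's rotary_emb and are applied to query_states/key_states right before repeat_kv expands the key/value heads.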
