
Commit 3abd790: use with and without cache
faaany committed Jun 3, 2024
1 parent a2a969e
Showing 1 changed file with 57 additions and 53 deletions.

optimum/exporters/ipex/modeling_utils.py (110 changes: 57 additions & 53 deletions)
@@ -172,60 +172,60 @@ def qkv_gemm(self, hidden_states):
 
         return query, key, value

-    def rope(self, query, key, kv_seq_len, position_ids, use_cache):
-        if use_cache:
-            key = self.ipex_rope(
-                key,
-                position_ids,
-                self.num_key_value_heads,
-                self.head_dim,
-                self.head_dim // 2,
-                self.head_dim,
-                kv_seq_len,
-            )
-            query = self.ipex_rope(
-                query,
-                position_ids,
-                self.num_heads,
-                self.head_dim,
-                self.head_dim // 2,
-                self.head_dim,
-                kv_seq_len,
-            )
+    def rope(self, query, key, kv_seq_len, position_ids):
+        key = self.ipex_rope(
+            key,
+            position_ids,
+            self.num_key_value_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+        query = self.ipex_rope(
+            query,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
         return query, key
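For reference, ipex_rope above is IPEX's fused rotary position embedding kernel: it rotates query and key vectors by position-dependent angles before attention. Below is a minimal plain-PyTorch sketch of the interleaved-pair rotary formulation it conceptually corresponds to; the tensor layout, the base of 10000, and the helper name reference_rope are assumptions for illustration, not the IPEX implementation.

import torch

def reference_rope(x, position_ids, base=10000.0):
    # x: (batch, seq_len, num_heads, head_dim); position_ids: (batch, seq_len)
    # NOTE: illustrative reference only, not the fused ipex_rope op.
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = position_ids[..., None].float() * inv_freq  # (batch, seq_len, head_dim // 2)
    cos = angles.cos()[..., None, :]                      # broadcast over the head axis
    sin = angles.sin()[..., None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin                   # rotate each (even, odd) pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out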

-    def sdpa(self, query, key, value, past_key_value, attention_mask, use_cache):
-        if use_cache:
-            # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
-            # to decide which beam should be used, making the scaled dot-product attention more efficient.
-            (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
-                query,
-                key,
-                value,
-                math.sqrt(self.head_dim),
-                past_key_value,
-                None,
-                attention_mask,
-            )
-        else:
-            value_states = value.transpose(1, 2)
-            query_states = query.transpose(1, 2)
-            key_states = key.transpose(1, 2)
+    def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask):
+        # This ipex op pre-allocates buffers for past_key_values and uses the beam index history
+        # to decide which beam should be used, making the scaled dot-product attention more efficient.
+        (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+            query,
+            key,
+            value,
+            math.sqrt(self.head_dim),
+            past_key_value,
+            None,
+            attention_mask,
+        )
+        return attn_output, past_key_value, attn_weights
+
+    def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask):
+        value_states = value.transpose(1, 2)
+        query_states = query.transpose(1, 2)
+        key_states = key.transpose(1, 2)

-            past_key_value = None
-            # repeat k/v heads if n_kv_heads < n_heads
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+        past_key_value = None
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)

-            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-            if attention_mask is not None:
-                attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        if attention_mask is not None:
+            attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)

         return attn_output, past_key_value, attn_weights
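In the split above, sdpa_with_cache delegates to the fused ipex_scale_dot_product op, which manages the pre-allocated past_key_value buffers, while sdpa_without_cache computes plain softmax(QK^T / sqrt(head_dim)) V in eager mode. As a hedged sketch, the same eager math can also be expressed with PyTorch's built-in scaled_dot_product_attention, assuming the tensors are already in (batch, num_heads, seq_len, head_dim) layout, key/value heads are already repeated, and the mask is additive; note the built-in does not return the attention weights. The helper name below is assumed for illustration.

import torch.nn.functional as F

def eager_sdpa_equivalent(query_states, key_states, value_states, attention_mask=None):
    # Same math as the matmul / softmax / matmul sequence in sdpa_without_cache,
    # delegated to PyTorch's fused kernel (numerics may differ slightly from the
    # explicit fp32 upcast used above).
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=attention_mask
    )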

@@ -266,11 +266,15 @@ def forward(
         kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len
 
         query, key, value = self.qkv_gemm(hidden_states)
-        query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache)
-
-        attn_output, past_key_value, attn_weights = self.sdpa(
-            query, key, value, past_key_value, attention_mask, use_cache
-        )
+        if use_cache:
+            query, key = self.rope(query, key, kv_seq_len, position_ids)
+            attn_output, past_key_value, attn_weights = self.sdpa_with_cache(
+                query, key, value, past_key_value, attention_mask
+            )
+        else:
+            attn_output, past_key_value, attn_weights = self.sdpa_without_cache(
+                query, key, value, past_key_value, attention_mask
+            )
         attn_output = attn_output.transpose(1, 2).view(bsz, seq_len, self.hidden_size)
 
         if hasattr(self, "mha_linear_add"):
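As a small worked example of the kv_seq_len bookkeeping near the top of forward (values assumed): during prefill there is no past, so kv_seq_len equals the prompt length; on each decode step with use_cache=True, seq_len is 1 and the cached length has grown by the tokens already processed.

# prefill: 10 prompt tokens, no cache yet
seq_len, past_key_value = 10, None
kv_seq_len = seq_len + past_key_value[0].size(-2) if past_key_value is not None else seq_len  # 10

# first decode step: 1 new token on top of 10 cached key/value positions (past_len assumed)
seq_len, past_len = 1, 10
kv_seq_len = seq_len + past_len  # 11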
