From 013cdc287f98ee4541e6149336f1bc9ff8984947 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Thu, 11 Jul 2024 17:53:33 +0400
Subject: [PATCH] Fix model patching for internlm2 (#814)

---
 optimum/exporters/openvino/model_patcher.py | 36 +++++++++++++--------
 1 file changed, 23 insertions(+), 13 deletions(-)

diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 7812682b8b..d7b52315ca 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -1073,10 +1073,13 @@ def rotate_half(x):
         x2 = x[..., x.shape[-1] // 2 :]
         return torch.cat((-x2, x1), dim=-1)
 
-    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         """Applies Rotary Position Embedding to the query and key tensors."""
-        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+        if position_ids is not None:
+            cos = cos[position_ids]
+            sin = sin[position_ids]
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
@@ -1113,17 +1116,24 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     value_states = value_states.transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    is_legacy = not hasattr(self, "layer_idx")
+
+    if is_legacy:
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if past_key_value is not None:
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        past_key_value = (key_states, value_states) if use_cache else None
+    else:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    past_key_value = (key_states, value_states) if use_cache else None
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
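
The first hunk makes `position_ids` optional so the same nested helper works with both rotary-embedding conventions: older InternLM2 remote-code revisions pass full cos/sin tables that still need indexing by `position_ids`, while newer `transformers`-style rotary layers already return per-position cos/sin. A minimal, self-contained sketch of that behaviour (not part of the patch; the toy rotary table and tensor shapes are illustrative only):

```python
# Standalone sketch illustrating why apply_rotary_pos_emb accepts position_ids=None:
# the legacy path indexes a full cos/sin table, the new path receives cos/sin
# already gathered per position. Shapes here are arbitrary toy values.
import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    if position_ids is not None:
        # legacy convention: cos/sin are [max_seq_len, head_dim] tables
        cos = cos[position_ids]
        sin = sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


batch, heads, seq, head_dim, max_len = 1, 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# toy rotary table, same construction as standard RoPE
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_table, sin_table = emb.cos(), emb.sin()     # [max_len, head_dim]
position_ids = torch.arange(seq).unsqueeze(0)   # [batch, seq]

# legacy call: full table plus position_ids
q1, k1 = apply_rotary_pos_emb(q, k, cos_table, sin_table, position_ids)
# new call: caller gathers cos/sin per position beforehand
q2, k2 = apply_rotary_pos_emb(q, k, cos_table[position_ids], sin_table[position_ids])
assert torch.allclose(q1, q2) and torch.allclose(k1, k2)
```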
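
The second hunk is the actual fix: newer InternLM2 remote-code revisions adopt the `transformers` `Cache` API (the attention module gains a `layer_idx` attribute and `past_key_value` exposes `update()`), while older revisions still pass legacy `(key, value)` tuples, so the patched forward dispatches on `hasattr(self, "layer_idx")`. A reduced sketch of that dispatch, assuming `transformers>=4.36` for `DynamicCache`; the `FakeAttention` stand-in and tensor shapes are illustrative, not the patched module:

```python
# Reduced sketch of the cache dispatch introduced in the second hunk.
import torch
from transformers.cache_utils import DynamicCache


def update_cache(attn, key_states, value_states, past_key_value, use_cache=True):
    if not hasattr(attn, "layer_idx"):
        # legacy remote-code revisions: past_key_value is a (key, value) tuple
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None
    else:
        # newer revisions: past_key_value is a transformers Cache object
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, attn.layer_idx)
    return key_states, value_states, past_key_value


class FakeAttention:
    pass


k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

legacy_attn = FakeAttention()          # no layer_idx -> tuple cache path
k1, v1, cache1 = update_cache(legacy_attn, k, v, past_key_value=None)

new_attn = FakeAttention()
new_attn.layer_idx = 0                 # layer_idx present -> Cache API path
k2, v2, cache2 = update_cache(new_attn, k, v, past_key_value=DynamicCache())
print(k1.shape, k2.shape)              # both torch.Size([1, 2, 4, 8]) on the first step
```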
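
The patched forward is exercised whenever an InternLM2 checkpoint is exported to OpenVINO through optimum-intel. A hedged usage example; the checkpoint name is an assumption for illustration, any InternLM2 remote-code model applies:

```python
# Assumed usage context: exporting an InternLM2 remote-code checkpoint with
# optimum-intel, which applies the patched attention forward during export.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "internlm/internlm2-chat-1_8b"  # assumed checkpoint for illustration
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The weather today is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```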