
Commit

fix
faaany committed Jun 7, 2024
1 parent 5f5d205 commit 22860f2
Showing 2 changed files with 30 additions and 16 deletions.
12 changes: 11 additions & 1 deletion optimum/exporters/ipex/model_patcher.py
@@ -19,7 +19,7 @@
     LlamaRMSNorm,
 )

-from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version

 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
@@ -29,6 +29,10 @@
 )


+# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
+_TRANSFORMERS_MIN_VERSION = "4.39.0"
+_TRANSFORMERS_MAX_VERSION = "4.41.2"
+
 _IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
 _IPEX_EXPORTED_TASK = ("text-generation",)

@@ -63,6 +67,12 @@ def patch_op(m, target_m, new_op_name, new_op):
 def _patch_llama_model(model):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
+    if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
+        ">", _TRANSFORMERS_MAX_VERSION
+    ):
+        raise ImportError(
+            f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
+        )
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
     convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
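For readers who want the intent of the new guard in isolation: _patch_llama_model now refuses to patch outside a verified transformers range. Below is a minimal, self-contained sketch of that check; the real code calls is_transformers_version from optimum.intel.utils.import_utils, while this sketch approximates it with packaging.version and a hypothetical helper name.

# Sketch only: mirrors the version gate added in _patch_llama_model above.
# _check_transformers_range is a hypothetical name used here for illustration.
import transformers
from packaging import version

_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.41.2"


def _check_transformers_range():
    current = version.parse(transformers.__version__)
    too_old = current < version.parse(_TRANSFORMERS_MIN_VERSION)
    too_new = current > version.parse(_TRANSFORMERS_MAX_VERSION)
    if too_old or too_new:
        raise ImportError(
            f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
        )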
34 changes: 19 additions & 15 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,15 +19,12 @@
 from torch import nn
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import repeat_kv
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

 from optimum.intel.utils.import_utils import is_ipex_version
-from optimum.intel.utils.modeling_utils import setattr_from_module
+from optimum.intel.utils.modeling_utils import _setattr_from_module


-# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
-_TRANSFORMERS_MIN_VERSION = "4.39.0"
-_TRANSFORMERS_MAX_VERSION = "4.41.2"
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"

@@ -142,12 +139,12 @@ def _llama_model_forward(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
 class _IPEXLlamaAttention(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
-        if is_ipex_version("<", "2.3.0"):
+        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
             raise ImportError(
-                "Only ipex version > 2.3.0 supports LinearAdd, IndirectAccessKVCacheAttention, RotaryEmbedding"
+                f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding"
             )
         super().__init__()
-        setattr_from_module(self, module)
+        _setattr_from_module(self, module)
         self.config = config
         self.distributed = distributed
         from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding
@@ -200,7 +197,7 @@ def rope(self, query, key, kv_seq_len, position_ids, use_cache):
         )
         return query, key

-    def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask):
+    def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
         # This ipex op pre-allocates buffers for past_key_values and uses beam index history
         # to decide which beam to attend to, making the scaled dot-product attention more efficient.
         (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
@@ -215,11 +212,14 @@ def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask):
         return attn_output, past_key_value, attn_weights

     # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341
-    def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask):
+    def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids):
         value_states = value.transpose(1, 2)
         query_states = query.transpose(1, 2)
         key_states = key.transpose(1, 2)

+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
         past_key_value = None
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
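For context on the two added lines: sdpa_without_cache now computes rotary position embeddings itself and applies them via transformers' apply_rotary_pos_emb. A simplified, self-contained sketch of that function (assuming query/key shaped [batch, heads, seq_len, head_dim] and cos/sin shaped [batch, seq_len, head_dim], as the LLaMA rotary embedding returns) looks like this:

import torch


def rotate_half(x):
    # Split the last (head) dimension in two and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension, then rotate query and key
    # so that relative position information is encoded in their dot products.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed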
@@ -277,7 +277,9 @@ def forward(
         query, key = self.rope(query, key, kv_seq_len, position_ids, use_cache)

         sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache
-        attn_output, past_key_value, attn_weights = sdpa(query, key, value, past_key_value, attention_mask)
+        attn_output, past_key_value, attn_weights = sdpa(
+            query, key, value, past_key_value, attention_mask, position_ids
+        )
         attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size)

         if hasattr(self, "mha_linear_add"):
@@ -295,10 +297,12 @@ def forward(
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
 class _IPEXLlamaMLP(nn.Module):
     def __init__(self, module, config, distributed=False) -> None:
-        if is_ipex_version("<", "2.3.0"):
-            raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
+        if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
+            raise ImportError(
+                f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd"
+            )
         super().__init__()
-        setattr_from_module(self, module)
+        _setattr_from_module(self, module)
         self.config = config
         self.distributed = distributed
         from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
@@ -334,7 +338,7 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **
 class _IPEXLlamaDecoderLayer(nn.Module):
     def __init__(self, module, config, distributed=False):
         super().__init__()
-        setattr_from_module(self, module)
+        _setattr_from_module(self, module)
         self.distributed = distributed
         self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed)
         self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
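Finally, the setattr_from_module to _setattr_from_module rename tracks the helper in optimum.intel.utils.modeling_utils that copies the wrapped Hugging Face module's attributes onto the IPEX replacement class. The sketch below illustrates the general idea only; it is an assumption about the helper's behavior, not the exact optimum implementation.

import torch.nn as nn


def _setattr_from_module(new_module: nn.Module, module: nn.Module) -> None:
    # Copy instance attributes (parameters, buffers, submodules, config fields)
    # from the original module so the wrapper behaves as a drop-in replacement.
    # Illustrative sketch; see optimum/intel/utils/modeling_utils.py for the real helper.
    for name, value in module.__dict__.items():
        setattr(new_module, name, value)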
