Enable ipex patching for v2.3 (#725)
* ipex 2.3 released

* skip tests

* skip testing without pkv

* add tests skip

* only llama2 with a head size of at least 64 supports IAKV

* cannot assert same outputs because do_sample=True

* rm tiny-llama model testing because it does not work for IAKV

* fix code style

* fix style

* rm tiny llama on test pipeline

* fix tests

* support use_cache=False

* rm use_cache in model_kwargs

* set use_cache

* Update optimum/intel/ipex/modeling_base.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* fix spelling error

* fix style

* add transformers version warning

* add compare results

* add warning

* set pad_token_id

* limited transformers

* fix transformers version

* update transformers version

* fix version

* temporary fix for multi-query model

* fix code style

* add transformers version tests

* Update .github/workflows/test_ipex.yml

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* check generation method

* Update optimum/intel/ipex/modeling_base.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* fix use_cache

* add hidden size limitation for patch

* add llama in tests

* add re-load tests

* fix hidden size check

* rm norm config

* add version variable

* fix import

* rm useless logger

* rm useless logging

* fix last round review

* Update .github/workflows/test_ipex.yml

* Update optimum/intel/ipex/modeling_base.py

* Update optimum/intel/ipex/modeling_base.py

* Update setup.py

* Update optimum/exporters/ipex/modeling_utils.py

* fix

* limit the new tokens of assisted decoding tests

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
jiqing-feng and echarlaix authored Jun 6, 2024
1 parent d5dbb3d commit f06f504
Showing 6 changed files with 164 additions and 56 deletions.
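For orientation (not part of this diff), a minimal usage sketch of the feature this commit enables: loading a llama checkpoint through optimum-intel's IPEXModelForCausalLM so that, when the requirements introduced below are met (ipex >= 2.3.0, a verified transformers version, sufficient hidden size), generation runs through the patched RotaryEmbedding/IAKV path. The checkpoint id is a placeholder, and passing use_cache here is an assumption based on the ipex_jit_trace signature changed in this commit.

# Illustrative sketch, not part of the commit.
from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint id
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("The theory of relativity states that", max_new_tokens=16)[0]["generated_text"])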
2 changes: 2 additions & 0 deletions .github/workflows/test_ipex.yml
@@ -18,6 +18,7 @@ jobs:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
transformers-version: [4.39.0, 4.41.2]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
@@ -32,6 +33,7 @@ jobs:
python -m pip install --upgrade pip
pip install torch torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu
pip install .[ipex,tests]
pip install transformers==${{ matrix.transformers-version }}
- name: Test with Pytest
run: |
pytest tests/ipex/
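The matrix above pins the two verified transformers releases. As a side note, a small sketch (not part of the commit) for checking locally that the installed transformers falls inside that window:

# Sketch: compare the installed transformers version against the verified range.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("transformers"))
low, high = Version("4.39.0"), Version("4.41.2")
if not (low <= installed <= high):
    print(f"transformers {installed} is outside the verified range {low} ~ {high}")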
11 changes: 7 additions & 4 deletions optimum/exporters/ipex/model_patcher.py
@@ -23,6 +23,7 @@
from optimum.intel.utils.import_utils import is_ipex_version

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_IPEXLlamaDecoderLayerRef,
_llama_attn_forward,
_llama_layer_norm_forward,
@@ -62,18 +63,20 @@ def patch_op(m, target_m, new_op_name, new_op):


def _patch_llama_model(model):
if is_ipex_version("<", "2.5.0"):
raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports RotaryEmbedding and IndirectAccessKVCacheAttention"
)

from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, RotaryEmbedding

ipex_rope = RotaryEmbedding(
model.config.max_position_embeddings,
model.config.hidden_size // model.config.num_attention_heads,
model.config.rope_theta,
model.config.architectures[0],
)
ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=model.config.max_position_embeddings)
patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)

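For readers unfamiliar with the patching mechanism: patch_op (defined earlier in this file) attaches the constructed IPEX ops as attributes on every LlamaAttention instance, so the patched forward can call them as self.ipex_rope and self.ipex_scale_dot_product. A simplified, hypothetical sketch of that idea (the real helper may differ):

# Simplified sketch of a patch_op-style helper; illustrative only.
import torch.nn as nn

def patch_op_sketch(model: nn.Module, target_cls, new_op_name: str, new_op: nn.Module) -> None:
    # Walk the module tree and attach the new op on every matching submodule.
    for module in model.modules():
        if isinstance(module, target_cls):
            setattr(module, new_op_name, new_op)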
77 changes: 48 additions & 29 deletions optimum/exporters/ipex/modeling_utils.py
@@ -19,9 +19,15 @@
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from optimum.intel.utils.import_utils import is_ipex_version
from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.41.2"
_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
@@ -51,27 +57,27 @@ def _llama_attn_forward(
query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# Use ipex op to make rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
# Use ipex op to make rotary position embedding more efficient.
key = self.ipex_rope(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self.ipex_rope(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
# This ipex op pre-allocates buffers for past_key_values and uses beam index history
# to decide which beam should be used, making the attention scaled dot product more efficient.
(attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
@@ -87,6 +93,8 @@
value_states = value.transpose(1, 2)
query_states = query.transpose(1, 2)
key_states = key.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
kv_seq_len = key_states.shape[-2]

past_key_value = None
@@ -219,8 +227,16 @@ def _llama_model_forward(
# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayerRef(nn.Module):
def __init__(self, module, config, distributed=False):
if is_ipex_version("<", "2.5.0"):
raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul and LinearAdd"
)
if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
">", _TRANSFORMERS_MAX_VERSION
):
raise ImportError(
f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
)

from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

@@ -278,7 +294,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
)
if not self.distributed:
if hasattr(self, "mha_linear_add"):
hidden_states = self.mha_linear_add(hidden_states, residual)
else:
hidden_states = self.self_attn.o_proj(hidden_states)
@@ -288,12 +304,15 @@
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

mlp_gate = self.linear_silu_mul(hidden_states)

if not self.distributed:
hidden_states = self.mlp_linear_add(mlp_gate, residual)
if hasattr(self, "linear_silu_mul"):
mlp_gate = self.linear_silu_mul(hidden_states)
if hasattr(self, "mlp_linear_add"):
hidden_states = self.mlp_linear_add(mlp_gate, residual)
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = residual + hidden_states
else:
hidden_states = self.mlp.down_proj(mlp_gate)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
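The decoder-layer forward above now dispatches on whether the fused IPEX ops (LinearAdd, Linear2SiluMul) were attached, falling back to the original projections otherwise (for example in the distributed case). An illustrative toy of that hasattr-based dispatch, independent of IPEX:

# Illustrative toy of the "use fused op if attached, else fall back" pattern.
import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, hidden_size: int, fused: bool = True):
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        if fused:
            # Stand-in for ipex's LinearAdd: projection and residual add in one call.
            self.mha_linear_add = lambda x, residual: self.o_proj(x) + residual

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "mha_linear_add"):
            return self.mha_linear_add(hidden_states, residual)
        return residual + self.o_proj(hidden_states)

x = torch.randn(2, 16)
print(ToyBlock(16, fused=False)(x, x).shape)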
58 changes: 51 additions & 7 deletions optimum/intel/ipex/modeling_base.py
@@ -18,7 +18,7 @@
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import intel_extension_for_pytorch as ipex
import torch
@@ -50,7 +50,7 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
from ..generation.modeling import prepare_jit_inputs
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
@@ -60,10 +60,11 @@


_IPEX_SUPPORT_MODEL_TYPES = ("llama",)
_IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")


def _is_patched_with_ipex(model, task):
if is_ipex_version("<", "2.5.0"):
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
return False

if isinstance(model, torch.jit.ScriptModule):
@@ -73,7 +74,12 @@ def _is_patched_with_ipex(model, task):
return True
return False
else:
return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK
# The ipex IAKV op in the patched model requires a hidden size of at least 64
return (
model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
and task in _IPEX_EXPORTED_TASK
and model.config.hidden_size >= 64
)


def ipex_jit_trace(model, task, use_cache):
@@ -83,6 +89,7 @@ def ipex_jit_trace(model, task, use_cache):

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
# Todo: integrate in prepare_jit_inputs.
sample_inputs = get_dummy_input(model, return_dict=True)
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()
@@ -92,9 +99,10 @@

model.config.return_dict = False

if "past_key_values" in sample_inputs and use_cache:
# Make sure the model will output past_key_values in generation tasks
model.config.use_cache = True
if "past_key_values" in sample_inputs:
model.config.use_cache = use_cache
if not use_cache:
sample_inputs.pop("past_key_values")

model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
# Disable repack while jit tracing to reduce the memory
@@ -522,6 +530,23 @@ def _prepare_past_key_values(self, input_ids):

return past_key_values

# Temporary fix, will be deleted when https://github.com/huggingface/transformers/pull/31226 is released.
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
if not model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = None
return model_kwargs

past_length = 0
if "past_key_values" in model_kwargs:
past_length = model_kwargs["past_key_values"][0][0].shape[-2]
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
return model_kwargs

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -561,6 +586,25 @@ def forward(

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
generation_method = generation_config.get_generation_mode().value
if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
raise ValueError(
f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)

return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
return super().generate(*args, **kwargs)


def _prepare_inputs_for_generation_for_llama(
input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
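The new _prepare_generation_config override relies on transformers' GenerationConfig.get_generation_mode() (available in the pinned versions) to reject unsupported decoding strategies. A short sketch, for illustration only, showing how common configs map onto the allowed mode names:

# Sketch: which generation modes pass the new guard (assumes transformers 4.39.0 ~ 4.41.2).
from transformers import GenerationConfig

ALLOWED = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")

for cfg in (GenerationConfig(), GenerationConfig(do_sample=True), GenerationConfig(num_beams=4)):
    mode = cfg.get_generation_mode().value
    print(f"{mode}: {'supported' if mode in ALLOWED else 'rejected'}")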
2 changes: 1 addition & 1 deletion setup.py
@@ -62,7 +62,7 @@
"neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"],
"openvino": ["openvino>=2023.3", "nncf>=2.10.0", "openvino-tokenizers[transformers]"],
"nncf": ["nncf>=2.10.0"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.36.0,<4.39.0"],
"ipex": ["intel-extension-for-pytorch", "transformers>=4.39.0,<=4.41.2"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,