refactor CPU llama inference code #728

Merged · 31 commits · Jun 7, 2024
Commits
5351f4a
ipex 2.3 released
jiqing-feng May 23, 2024
d1d0ca0
refactor IPEXLlamaAttention
faaany May 25, 2024
bd5706c
Merge branch 'huggingface:main' into ipex-cpu
faaany May 26, 2024
e61382b
Merge branch 'main' of https://github.com/faaany/optimum-intel into i…
faaany May 26, 2024
48b205e
change to Ref
faaany May 26, 2024
404486a
Merge branch 'ipex-cpu' of https://github.com/faaany/optimum-intel in…
faaany May 26, 2024
4ea8a47
remove Ref
faaany May 27, 2024
1f98d6d
skip tests
jiqing-feng May 27, 2024
d3ce377
skip tests
jiqing-feng May 27, 2024
b2b93bb
skip testing without pkv
jiqing-feng May 27, 2024
ec0f641
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
64dcde4
add tests skip
jiqing-feng May 27, 2024
945f6b6
only llama2 with a head size of at least 64 supports IAKV
jiqing-feng May 27, 2024
0733625
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
c8922f3
cannot assert same outputs because do_sample=True
jiqing-feng May 27, 2024
0ff1d7b
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 27, 2024
2ddfa7a
rm tiny-llama model testing because it does not work with IAKV
jiqing-feng May 27, 2024
f4e887d
fix code style
jiqing-feng May 28, 2024
923e233
Merge branch 'rename' of https://github.com/jiqing-feng/optimum-intel…
faaany May 28, 2024
74f132e
refine docstring
faaany May 28, 2024
e130345
fix duplicated code
faaany May 30, 2024
14673da
refactor attention forward
faaany Jun 3, 2024
a2a969e
add use_cache for rope
faaany Jun 3, 2024
3abd790
use with and without cache
faaany Jun 3, 2024
82bd0c7
refine code
faaany Jun 3, 2024
de2cc43
add reference link
faaany Jun 4, 2024
1385f97
Merge branch 'main' into ipex-cpu
faaany Jun 6, 2024
752aba6
bug fix
faaany Jun 6, 2024
1ef8d56
use reshape
faaany Jun 6, 2024
5f5d205
Apply suggestions from code review
faaany Jun 6, 2024
22860f2
fix
faaany Jun 6, 2024
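
Two of the commits above carry technical constraints worth unpacking. Commit 945f6b6 notes that the IndirectAccessKVCache (IAKV) path is only supported for llama2 when the per-head dimension is at least 64. A minimal sketch of that gate, assuming the head size is derived as in the diff below (`hidden_size // num_attention_heads`); the helper name is illustrative, not code from the PR:

```python
# Illustrative gate for the constraint in commit 945f6b6 (not PR code):
# take the IAKV fast path only when the per-head dimension is >= 64.
def supports_iakv(config) -> bool:
    head_size = config.hidden_size // config.num_attention_heads
    return config.model_type == "llama" and head_size >= 64
```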
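Commit c8922f3 makes a related testing point: with `do_sample=True`, generation is stochastic, so two runs cannot be compared token for token. A hedged sketch of the kind of assertion that stays stable under sampling (the function name is illustrative, not the PR's test code):

```python
import torch

# Illustrative check: under do_sample=True only structural properties
# such as shape and dtype are deterministic across runs.
def assert_sampled_outputs_compatible(a: torch.Tensor, b: torch.Tensor) -> None:
    assert a.shape == b.shape and a.dtype == b.dtype
```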
24 changes: 4 additions & 20 deletions optimum/exporters/ipex/model_patcher.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaModel,
@@ -23,8 +22,7 @@
 from optimum.intel.utils.import_utils import is_ipex_version
 
 from .modeling_utils import (
-    _IPEXLlamaDecoderLayerRef,
-    _llama_attn_forward,
+    _IPEXLlamaDecoderLayer,
     _llama_layer_norm_forward,
     _llama_model_forward,
 )
@@ -62,26 +60,12 @@ def patch_op(m, target_m, new_op_name, new_op):
 
 
 def _patch_llama_model(model):
-    if is_ipex_version("<", "2.3.0"):
-        raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
-
-    from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
-
-    ipex_rope = RotaryEmbedding(
-        model.config.max_position_embeddings,
-        model.config.hidden_size // model.config.num_attention_heads,
-        model.config.rope_theta,
-        model.config.architectures[0],
-    )
-    ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
-    patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
-    patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
+    if is_ipex_version("<", "2.3.0"):
+        raise ImportError("Only ipex version >= 2.3.0 supports llama model patching")
 
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
-    convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
 
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
     return model
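
The patcher above leans on two helpers, `convert_functions` and `convert_class`, whose definitions sit elsewhere in this file and are not part of the hunk shown. A minimal sketch of the pattern they implement, a recursive monkey-patch over `nn.Module` children; the actual optimum-intel implementations may differ in detail:

```python
import types

def convert_functions(m, target_m, new_function_name, new_function):
    # Rebind a method (here typically `forward`) on every submodule
    # that is an instance of `target_m`, recursing through the tree.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_function_name, types.MethodType(new_function, sub_m))
        convert_functions(sub_m, target_m, new_function_name, new_function)

def convert_class(m, target_m, new_class, config):
    # Replace every `target_m` child with `new_class`, handing it the
    # original module so its weights can be reused, recursing throughout.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, config))
        convert_class(sub_m, target_m, new_class, config)
```

Seen through these helpers, the refactor replaces the old per-op `patch_op` wiring of rope and IAKV onto `LlamaAttention` with a single class swap: each `LlamaDecoderLayer` becomes an `_IPEXLlamaDecoderLayer`, which owns its attention implementation.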