support assisted decoding in ipex 2.4 #823

Merged
merged 19 commits on Sep 9, 2024
Changes from 1 commit
58 changes: 51 additions & 7 deletions optimum/intel/ipex/modeling_base.py
@@ -23,10 +23,10 @@

import intel_extension_for_pytorch as ipex
import torch
+import transformers
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
-from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
from transformers import (
    AutoConfig,
    AutoModel,
@@ -43,6 +43,7 @@
    is_torch_xpu_available,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.generation.candidate_generator import _crop_past_key_values
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME
@@ -86,10 +87,37 @@ def _is_patched_with_ipex(model, task):


def _prepare_inputs_for_ipex_model(model, task, use_cache):
-    if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
-        return get_dummy_input(model, return_dict=True)
-    else:
-        return prepare_jit_inputs(model, task, use_cache)
+    sample_inputs = prepare_jit_inputs(model, task, use_cache)
+    if (
+        task in _IPEX_EXPORTED_GENERATION_TASKS
+        and _is_patched_with_ipex(model, task)
+        and "past_key_values" in sample_inputs
+    ):
+        # Only consider llama for now
+        assert len(sample_inputs["past_key_values"][0][0].shape) == 4
+        max_position = model.config.max_position_embeddings
+        batch_size = sample_inputs["input_ids"].shape[0]
+        past_length = sample_inputs["past_key_values"][0][0].shape[2]
+        num_attn = sample_inputs["past_key_values"][0][0].shape[1]
+        d_k = sample_inputs["past_key_values"][0][0].shape[-1]
+        dtype = sample_inputs["past_key_values"][0][0].dtype
+
+        num_layers = len(sample_inputs["past_key_values"])
+        beam_idx_tmp = torch.zeros((max_position, batch_size), dtype=torch.long).contiguous()
+        past_key_values = tuple(
+            [
+                (
+                    torch.zeros(1, past_length, past_length, 1, dtype=torch.long).contiguous(),
+                    torch.zeros(max_position, batch_size, num_attn, d_k, dtype=dtype).contiguous(),
+                    torch.zeros(max_position, batch_size, num_attn, d_k, dtype=dtype).contiguous(),
+                    beam_idx_tmp,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        sample_inputs["past_key_values"] = past_key_values
+
+    return sample_inputs


def ipex_jit_trace(model, task, use_cache):
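
For orientation, here is an illustrative sketch (all sizes are made up) of the two per-layer cache layouts involved: the standard transformers key/value pair that prepare_jit_inputs produces, versus the four-element tuple that the IPEX-patched llama graph is traced with above.

import torch

# Illustrative only; shapes mirror the tuples built in _prepare_inputs_for_ipex_model.
batch, heads, seq, d_k, max_position = 1, 4, 10, 8, 16

# Standard transformers layout: one (key, value) pair per layer,
# each shaped (batch, num_heads, seq_len, head_dim).
standard_layer = (
    torch.zeros(batch, heads, seq, d_k),
    torch.zeros(batch, heads, seq, d_k),
)

# Layout the IPEX-patched graph is traced with: a 4-tuple per layer whose
# key/value buffers are preallocated up to max_position_embeddings, plus two
# bookkeeping tensors used for indirect (beam-indexed) access.
ipex_layer = (
    torch.zeros(1, seq, seq, 1, dtype=torch.long),        # bookkeeping tensor
    torch.zeros(max_position, batch, heads, d_k),          # key buffer
    torch.zeros(max_position, batch, heads, d_k),          # value buffer
    torch.zeros(max_position, batch, dtype=torch.long),    # beam index
)
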
@@ -472,6 +500,8 @@ def __init__(

        if self._is_ipex_exported:
            self._reorder_cache = _ipex_reorder_cache
+            transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
+            transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
        else:
            # Check if _reorder_cache is a static method
            if isinstance(self.model_cls.__dict__["_reorder_cache"], staticmethod):
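
A note on why both module attributes are rebound above: transformers' assisted-decoding code imports _crop_past_key_values by name, so each importing module holds its own reference. The toy sketch below (hypothetical module names, not the real transformers modules) shows why patching only the defining module would not be enough.

import types

# Build two throwaway modules to mimic `from x import f` followed by patching x.f.
x = types.ModuleType("x")
x.f = lambda: "original"

y = types.ModuleType("y")
y.f = x.f  # what `from x import f` does inside module y: copy the reference

x.f = lambda: "patched"
print(x.f())  # "patched"
print(y.f())  # still "original" -> y.f must be rebound as well
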
@@ -609,9 +639,9 @@ def _prepare_generation_config(
        return generation_config, model_kwargs

    def generate(self, *args, **kwargs):
-        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if is_ipex_version("<", "2.5.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
            raise ValueError(
-                f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+                f"Assisted decoding is not supported for patched models if ipex < 2.5, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
            )
        return super().generate(*args, **kwargs)
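
For context, assisted decoding is requested through the assistant_model argument of generate; a minimal sketch is below (checkpoint names are placeholders, export=True is assumed to trigger the IPEX export path, and whether the patched path is taken depends on the installed IPEX version).

from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"                 # placeholder target checkpoint
assistant_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # placeholder draft checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
assistant = AutoModelForCausalLM.from_pretrained(assistant_id)

inputs = tokenizer("Assisted decoding lets a draft model propose tokens.", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
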

@@ -693,3 +723,17 @@ def _ipex_reorder_cache(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )
+
+
+def _ipex_crop_past_key_values(model, past_key_values, max_length):
+    if isinstance(model, IPEXModel):
+        new_past_key_values = []
+        for i in range(len(past_key_values)):
+            pkv = []
+            pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
+            pkv += [past_key_values[i][_] for _ in range(1, 4)]
+            new_past_key_values.append(tuple(pkv))
+        new_past_key_values = tuple(new_past_key_values)
+        return new_past_key_values
+    else:
+        return _crop_past_key_values(model, past_key_values, max_length)
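
A small illustration of what the IPEX branch above does (dummy shapes, not real model outputs): only the first element of each layer tuple carries sequence dimensions that need trimming after rejected draft tokens; the preallocated key/value buffers and the beam index pass through unchanged.

import torch

max_position, batch, heads, d_k, seq, max_length = 16, 1, 4, 8, 10, 7

layer = (
    torch.zeros(1, seq, seq, 1, dtype=torch.long),        # bookkeeping tensor to crop
    torch.zeros(max_position, batch, heads, d_k),          # key buffer (kept as-is)
    torch.zeros(max_position, batch, heads, d_k),          # value buffer (kept as-is)
    torch.zeros(max_position, batch, dtype=torch.long),    # beam index (kept as-is)
)

cropped = (layer[0][:, :max_length, :max_length, :],) + layer[1:]
print(cropped[0].shape)  # torch.Size([1, 7, 7, 1])
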
2 changes: 1 addition & 1 deletion tests/ipex/test_modeling.py
@@ -281,7 +281,7 @@ def test_pipeline(self, model_arch):
    # High optimized model llama is not supported assisted decoding for now.
    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_assisted_decoding(self, model_arch):
-        if model_arch == "llama2":
+        if model_arch == "llama2" and is_ipex_version("<", "2.5.0"):
            return
        model_id = MODEL_NAMES[model_arch]
        tokenizer = AutoTokenizer.from_pretrained(model_id)