Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support assisted decoding in ipex 2.4 #823

Merged
merged 19 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions optimum/exporters/ipex/model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from optimum.exporters.onnx.model_configs import (
FalconOnnxConfig,
GPT2OnnxConfig,
LlamaOnnxConfig,
)
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


DEFAULT_DUMMY_SHAPES["batch_size"] = 1
echarlaix marked this conversation as resolved.
Show resolved Hide resolved


class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1)
self.max_position_embeddings = normalized_config.max_position_embeddings

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape_init = (1, self.sequence_length, self.sequence_length, 1)
shape_beam_idx_tmp = (self.max_position_embeddings, self.batch_size)
shape_kv = (
self.max_position_embeddings,
self.batch_size,
self.num_key_value_heads,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_int_tensor(shape_init, max_value=1, framework=framework).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_int_tensor(shape_beam_idx_tmp, max_value=1, framework=framework).contiguous(),
)
for _ in range(self.num_layers)
]


class IPEXDummyTextInputGenerator(DummyTextInputGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
super().__init__(task, normalized_config, batch_size, **kwargs)


class LlamaIPEXConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class FalconIPEXConfig(FalconOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class GPT2IPEXConfig(GPT2OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


ipex_onnx_config = {"llama": LlamaIPEXConfig, "falcon": FalconIPEXConfig, "gpt2": GPT2IPEXConfig}
83 changes: 69 additions & 14 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

import intel_extension_for_pytorch as ipex
import torch
import transformers
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
from transformers import (
AutoConfig,
AutoModel,
Expand All @@ -43,20 +43,24 @@
is_torch_xpu_available,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.generation.candidate_generator import _crop_past_key_values
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.tasks import make_backend_config_constructor_for_task
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.model_config import ipex_onnx_config
from ...exporters.ipex.model_patcher import (
_IPEX_EXPORTED_GENERATION_TASKS,
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_patch_model,
)
from ..generation.modeling import prepare_jit_inputs
from ..generation.modeling import get_float_type
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device

Expand Down Expand Up @@ -86,10 +90,35 @@ def _is_patched_with_ipex(model, task):


def _prepare_inputs_for_ipex_model(model, task, use_cache):
if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
return get_dummy_input(model, return_dict=True)
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
onnx_config_class = make_backend_config_constructor_for_task(
ipex_onnx_config[model.config.model_type], task=task
)
else:
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.dtype)
if "text-generation" in task:
onnx_config = onnx_config_class(
model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
)
else:
return prepare_jit_inputs(model, task, use_cache)
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

# Check attention_mask shape
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache:
past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
input_len = dummy_inputs["input_ids"].shape[-1]
attention_len = dummy_inputs["attention_mask"].shape[-1]
if attention_len != input_len + past_len:
dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
dummy_inputs["input_ids"].dtype
)

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


def ipex_jit_trace(model, task, use_cache):
Expand All @@ -103,11 +132,7 @@ def ipex_jit_trace(model, task, use_cache):
sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

model.config.return_dict = False

if "past_key_values" in sample_inputs:
model.config.use_cache = use_cache
if not use_cache:
sample_inputs.pop("past_key_values")
model.config.use_cache = use_cache

# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
# Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
Expand Down Expand Up @@ -372,7 +397,7 @@ def _init_warmup(self):
# TODO : add warmup for IPEX exported model
if not self._is_ipex_exported:
use_cache = "past_key_values" in self.input_names
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
if self._device.type != "cpu":
dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
for _ in range(2):
Expand Down Expand Up @@ -652,11 +677,28 @@ def _prepare_generation_config(
return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if self._is_ipex_exported and kwargs.get("assistant_model", None):
if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
return super().generate(*args, **kwargs)
# Patch functions to support IAKV cache
if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
elif self._is_ipex_exported:
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values

try:
result = super().generate(*args, **kwargs)
except Exception as e:
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
raise e

if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values

return result


def _ipex_prepare_inputs_for_generation(
Expand Down Expand Up @@ -736,3 +778,16 @@ def _ipex_reorder_cache(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
new_past_key_values = []
for i in range(len(past_key_values)):
pkv = []
pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
pkv += [past_key_values[i][_] for _ in range(1, 4)]
jiqing-feng marked this conversation as resolved.
Show resolved Hide resolved
new_past_key_values.append(tuple(pkv))
new_past_key_values = tuple(new_past_key_values)
return new_past_key_values
return _crop_past_key_values(model, past_key_values, max_length)
9 changes: 6 additions & 3 deletions tests/ipex/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,10 @@ def test_pipeline(self, model_arch):
self.assertEqual(pipe.device, model.device)
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))

# High optimized model llama is not supported assisted decoding for now.
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_assisted_decoding(self, model_arch):
# Patched models are not support assisted decoding for now.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
# Patched models are not support assisted decoding if ipex < 2.5.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"):
return
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)
Expand All @@ -296,11 +295,15 @@ def test_assisted_decoding(self, model_arch):
ipex_output_assisted = ipex_model.generate(
**tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4
)
ipex_output_assisted_2 = ipex_model.generate(
**tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
)
transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4)
transformers_output_assisted = transformers_model.generate(
**tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
)
self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
self.assertTrue(torch.equal(ipex_output, ipex_output_assisted_2))
self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))

@parameterized.expand(
Expand Down
Loading