support assisted decoding in ipex 2.4 #823

Merged · 19 commits · Sep 9, 2024
Changes from 12 commits
70 changes: 70 additions & 0 deletions optimum/exporters/ipex/model_config.py
@@ -0,0 +1,70 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from optimum.exporters.onnx.model_configs import LlamaOnnxConfig
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


class IPEXDummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.num_key_value_heads = normalized_config.num_key_value_heads
self.max_position_embeddings = normalized_config.max_position_embeddings

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape_init = (1, self.sequence_length, self.sequence_length, 1)
shape_beam_idx_tmp = (self.max_position_embeddings, self.batch_size)
shape_kv = (
self.max_position_embeddings,
self.batch_size,
self.num_key_value_heads,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_int_tensor(shape_init, max_value=1, framework=framework).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
self.random_int_tensor(shape_beam_idx_tmp, max_value=1, framework=framework).contiguous(),
)
for _ in range(self.num_layers)
]


class LlamaIPEXConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


ipex_onnx_config = {"llama": LlamaIPEXConfig}
74 changes: 60 additions & 14 deletions optimum/intel/ipex/modeling_base.py
@@ -23,10 +23,10 @@

import intel_extension_for_pytorch as ipex
import torch
import transformers
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
from transformers import (
AutoConfig,
AutoModel,
@@ -43,6 +43,7 @@
is_torch_xpu_available,
)
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.generation.candidate_generator import _crop_past_key_values
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto.auto_factory import _get_model_class as get_model_class
from transformers.utils import WEIGHTS_NAME
@@ -51,12 +52,14 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ...exporters.ipex.model_config import ipex_onnx_config
from ...exporters.ipex.model_patcher import (
_IPEX_EXPORTED_GENERATION_TASKS,
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_patch_model,
)
from ..generation.modeling import prepare_jit_inputs
from ..generation.modeling import get_float_type
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device

@@ -86,10 +89,23 @@ def _is_patched_with_ipex(model, task):


def _prepare_inputs_for_ipex_model(model, task, use_cache):
if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
return get_dummy_input(model, return_dict=True)
task = _TASK_ALIASES.get(task, task)
signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
onnx_config_class = ipex_onnx_config[model.config.model_type]
else:
return prepare_jit_inputs(model, task, use_cache)
onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
float_dtype = get_float_type(model.dtype)
if "text-generation" in task:
onnx_config = onnx_config_class(
model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
)
else:
onnx_config = onnx_config_class(model.config)

dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")

return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


def ipex_jit_trace(model, task, use_cache):
@@ -103,11 +119,9 @@ def ipex_jit_trace(model, task, use_cache):
sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

model.config.return_dict = False

if "past_key_values" in sample_inputs:
model.config.use_cache = use_cache
if not use_cache:
sample_inputs.pop("past_key_values")
model.config.use_cache = use_cache
if "past_key_values" in sample_inputs and not use_cache:
sample_inputs.pop("past_key_values")

# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
# Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
@@ -372,7 +386,7 @@ def _init_warmup(self):
# TODO : add warmup for IPEX exported model
if not self._is_ipex_exported:
use_cache = "past_key_values" in self.input_names
dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
if self._device.type != "cpu":
dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
for _ in range(2):
@@ -652,11 +666,30 @@ def _prepare_generation_config(
return generation_config, model_kwargs

def generate(self, *args, **kwargs):
if self._is_ipex_exported and kwargs.get("assistant_model", None):
if is_ipex_version("<", "2.5.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
raise ValueError(
f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
f"Assisted decoding is not supported for patched models if ipex < 2.5, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
return super().generate(*args, **kwargs)
# Patch functions to support IAKV cache
if self._is_ipex_exported:
_patch_crop_past_key_values()
try:
result = super().generate(*args, **kwargs)
except Exception as e:
_unpatch_crop_past_key_values()
raise e
_unpatch_crop_past_key_values()
return result


def _patch_crop_past_key_values():
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
Collaborator:

Could this be unpatched after export?

@jiqing-feng (Collaborator, Author) · Sep 9, 2024:

As we discussed in Teams, this is the only way to enable all assisted decoding cases:

  1. transformers target model + ipex draft model
  2. ipex target model + transformers draft model
  3. ipex target model + ipex draft model

The _crop_past_key_values function lives at the same level as the model, so we cannot unpatch it inside the generate function; the unpatching would only run after generate, see here.

I check the model type inside _ipex_crop_past_key_values. It only affects IPEX models; transformers models fall through to the original function, so there is no risk even if we don't unpatch.
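
For readers of this thread, a minimal, hypothetical sketch of case 2 (IPEX target model + transformers draft model), assuming ipex supports assisted decoding for the patched model. The model IDs, `export=True`, and generation arguments are illustrative assumptions, not taken from this PR:

# Hypothetical sketch of case 2: IPEX target model + transformers draft model.
# The model IDs below are placeholders; any compatible causal-LM checkpoints
# that share the same tokenizer/vocabulary should work.
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

target_id = "meta-llama/Llama-2-7b-hf"  # patched and jit-traced through IPEX
draft_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small draft model

tokenizer = AutoTokenizer.from_pretrained(target_id)
target = IPEXModelForCausalLM.from_pretrained(target_id, export=True)
draft = AutoModelForCausalLM.from_pretrained(draft_id)

inputs = tokenizer("This is a sample", return_tensors="pt")
# generate() patches _crop_past_key_values so the IAKV cache of the patched
# target model can be cropped when rejected draft tokens are discarded.
outputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])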



def _unpatch_crop_past_key_values():
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
transformers.generation.utils._crop_past_key_values = _crop_past_key_values


def _ipex_prepare_inputs_for_generation(
Expand Down Expand Up @@ -736,3 +769,16 @@ def _ipex_reorder_cache(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)


def _ipex_crop_past_key_values(model, past_key_values, max_length):
if isinstance(model, IPEXModel):
new_past_key_values = []
for i in range(len(past_key_values)):
pkv = []
pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
pkv += [past_key_values[i][_] for _ in range(1, 4)]
new_past_key_values.append(tuple(pkv))
new_past_key_values = tuple(new_past_key_values)
return new_past_key_values
return _crop_past_key_values(model, past_key_values, max_length)
5 changes: 2 additions & 3 deletions tests/ipex/test_modeling.py
@@ -281,11 +281,10 @@ def test_pipeline(self, model_arch):
self.assertEqual(pipe.device, model.device)
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))

# High optimized model llama is not supported assisted decoding for now.
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_assisted_decoding(self, model_arch):
# Patched models are not support assisted decoding for now.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
# Patched models do not support assisted decoding if ipex < 2.5.
if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.5.0"):
return
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)