Commit 4854ce1: qwen

eaidova committed Mar 12, 2024 · 1 parent 53686cf
Showing 4 changed files with 323 additions and 3 deletions.
120 changes: 119 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -20,7 +20,6 @@

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
@@ -31,6 +30,8 @@
)
from optimum.utils.normalized_config import NormalizedTextConfig

from .model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher, QwenModelPatcher


def init_model_configs():
supported_model_types = [
@@ -268,3 +269,120 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)


class QwenDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.kv_channels = normalized_config.kv_channels

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
past_value_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
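# Editorial note: Qwen keeps the sequence axis at dim 1, i.e. the cache layout is
# (batch, seq_len, num_heads, kv_channels) rather than the more common
# (batch, num_heads, seq_len, head_dim); e.g. batch_size=2, sequence_length=16,
# num_attention_heads=4, kv_channels=8 would yield per-layer key/value tensors of shape (2, 16, 4, 8).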
return [
(
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]


@register_in_tasks_manager("qwen", *["text-generation", "text-generation-with-past"])
class QwenOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, QwenDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = QwenDummyPastKeyValuesGenerator
no_position_ids = False

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
if self.use_past_in_inputs and self.use_cache_branch is not False:
input_names.append("past_key_values")

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

# refer to https://github.com/huggingface/optimum/pull/764
if (
self.use_past_in_inputs
and self.PAD_ATTENTION_MASK_TO_PAST
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (as done for Bloom); Qwen keeps the sequence length in dim 1 instead of dim -2.
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[1]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_present_length,
dim=1,
dtype=dummy_inputs["attention_mask"].dtype,
)

return dummy_inputs

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
Args:
inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill.
direction (`str`):
either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
output mapping, this is important for axes naming.
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 1: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 1: decoder_sequence_name}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return QwenModelPatcher(self, model, model_kwargs=model_kwargs)
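
For context, here is a minimal usage sketch (editorial addition, not part of this commit) of how the newly registered "qwen" export configuration is exercised end to end; it assumes the tiny test checkpoint listed in tests/openvino/utils_tests.py below and the public optimum-intel loading API:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-qwen"  # tiny test checkpoint from this commit's test table

# export=True runs the OpenVINO exporter, which now resolves the "qwen" architecture
# to QwenOpenVINOConfig; trust_remote_code=True is required because Qwen ships custom
# modeling code.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("This is a sample", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
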
200 changes: 199 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
@@ -14,7 +14,7 @@

import logging as log
import types
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -279,3 +279,201 @@ def __enter__(self):
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)


SUPPORT_SDPA = is_torch_version(">", "2.1.0")


def _qwen_rotate_half(x):
from einops import rearrange

x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)


def _qwen_apply_rotary_pos_emb(t, freqs):
cos, sin = freqs
rot_dim = freqs[0].shape[-1]
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin)
return torch.cat((t_, t_pass_), dim=-1).type_as(t)


def _qwen_quantize_cache_v(fdata, bits, qmax, qmin):
# fdata arrives already permuted by the caller from (batch, seq, heads, head_dim) to (batch, heads, seq, head_dim)
qtype = torch.uint8
device = fdata.device
shape = fdata.shape

fdata_cal = torch.flatten(fdata, 2)
fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
# Compute params
if qmax.device != fmax.device:
qmax = qmax.to(device)
qmin = qmin.to(device)
scale = (fmax - fmin) / (qmax - qmin)
zero = qmin - fmin / scale
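# Editorial worked example (comment only, not part of the commit): with qmin=0, qmax=255
# and a channel spanning fmin=-1.0 .. fmax=1.0, scale = 2 / 255 ≈ 0.00784 and
# zero = 0 - (-1.0) / 0.00784 ≈ 127.5, so a float value f is stored as
# clamp(f / scale + zero, 0, 255) in uint8 and can be recovered as (q - zero) * scale.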
scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
# Quantize
res_data = fdata / scale + zero
qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
return qdata.contiguous(), scale, zero


def _qwen_attention_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)

query, key, value = mixed_x_layer.split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if rotary_pos_emb_list is not None:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = _qwen_apply_rotary_pos_emb(query, q_pos_emb)
key = _qwen_apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)

if self.use_cache_quantization:
key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)
value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax)

if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
if self.use_cache_quantization:
# use_cache_quantization:
# present=((q_key,key_scale,key_zero_point),
# (q_value,value_scale,value_zero_point))
key = (
torch.cat((past_key[0], key[0]), dim=2),
torch.cat((past_key[1], key[1]), dim=2),
torch.cat((past_key[2], key[2]), dim=2),
)
value = (
torch.cat((past_value[0], value[0]), dim=2),
torch.cat((past_value[1], value[1]), dim=2),
torch.cat((past_value[2], value[2]), dim=2),
)
else:
# not use_cache_quantization:
# present=(key,value)
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)

if use_cache:
present = (key, value)
else:
present = None

if self.use_logn_attn and not self.training:
if self.use_cache_quantization:
seq_start = key[0].size(2) - query.size(1)
seq_end = key[0].size(2)
else:
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)

if self.use_flash_attn and not self.is_fp32 and query.is_cuda:
q, k, v = query, key, value
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
else:
registered_causal_mask = torch.tril(
torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
).view(1, 1, key.size(1), key.size(1))
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)

if not self.use_cache_quantization and SUPPORT_SDPA:
causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)]
if attention_mask is not None:
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill(
~causal_mask, torch.finfo(query.dtype).min
)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
attn_weight = None
else:
attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

attn_output = self.c_proj(context_layer)

outputs = (attn_output, present)
if output_attentions:
if self.use_flash_attn and not self.is_fp32:
raise ValueError("Cannot output attentions while using flash-attn")
else:
outputs += (attn_weight,)

return outputs


class QwenModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

self.original_fp16 = model.config.fp16
self.original_bf16 = model.config.bf16
model.config.bf16 = False
model.config.fp16 = False
if self.original_fp16 or self.original_bf16:
model.to(torch.float32)
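# Editorial note on the next line: calling rotary_emb(2048) pre-populates Qwen's cached
# cos/sin tables up to a 2048-token context, so tracing with short dummy inputs does not
# rely on lazily built buffers (interpretation; comment not present in the commit).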
model.transformer.rotary_emb(2048)

def __enter__(self):
super().__enter__()
for block in self._model.transformer.h:
block.attn._orig_forward = block.attn.forward
block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.transformer.h:
block.attn.forward = block.attn._orig_forward
self._model.config.bf16 = self.original_bf16
self._model.config.fp16 = self.original_fp16
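
To make the patching mechanism above easier to follow, here is a minimal, self-contained sketch (editorial addition with toy class names, not the real optimum-intel API) of the pattern QwenModelPatcher uses: a context manager rebinds each block's forward to an export-friendly function via types.MethodType and restores the original on exit.

import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, x):
        return x  # stands in for the original, non-export-friendly forward


def _patched_forward(self, x):
    return x * 2  # stands in for the export-friendly replacement


class ToyPatcher:
    def __init__(self, blocks):
        self._blocks = blocks

    def __enter__(self):
        for block in self._blocks:
            block._orig_forward = block.forward
            block.forward = types.MethodType(_patched_forward, block)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for block in self._blocks:
            block.forward = block._orig_forward


blocks = torch.nn.ModuleList([ToyAttention(), ToyAttention()])
with ToyPatcher(blocks):
    print(blocks[0](torch.ones(1)))  # patched forward: tensor([2.])
print(blocks[0](torch.ones(1)))  # original forward restored: tensor([1.])
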
5 changes: 4 additions & 1 deletion tests/openvino/test_modeling.py
@@ -495,12 +495,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"mpt",
"opt",
"pegasus",
"qwen",
"qwen2",
"stablelm",
)
GENERATION_LENGTH = 100
IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais")
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen")

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -531,6 +532,8 @@ def test_compare_to_transformers(self, model_arch):
)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
if model_arch == "qwen":
transformers_model.to(torch.float32)
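# Editorial note: the Qwen remote-code checkpoint may instantiate weights in bf16/fp16,
# so the reference model is cast to float32 to keep logits comparable with the fp32
# OpenVINO export (mirroring the cast done inside QwenModelPatcher).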
tokens = tokenizer(
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
)
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -72,6 +72,7 @@
"pegasus": "hf-internal-testing/tiny-random-pegasus",
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "Qwen/Qwen1.5-0.5B",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
