
Commit

chatglm export
eaidova committed Feb 20, 2024
1 parent c1064fd commit bbdca54
Showing 2 changed files with 289 additions and 3 deletions.
167 changes: 164 additions & 3 deletions optimum/exporters/openvino/model_configs.py
@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from packaging import version
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
DummyInputGenerator,
DummyPastKeyValuesGenerator,
DummyTextInputGenerator,
MistralDummyPastKeyValuesGenerator,
)
from optimum.utils.normalized_config import NormalizedTextConfig

from .model_patcher import MixtralModelPatcher
from .model_patcher import ChatGLMModelPatcher, MixtralModelPatcher


if TYPE_CHECKING:
@@ -70,6 +76,161 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
SUPPORTED_INPUT_NAMES = {
"input_ids",
"attention_mask",
"token_type_ids",
"position_ids",
}

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
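        # Unlike the stock generator, attention_mask is forced to all ones and position_ids to a contiguous per-batch position range.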
import torch

input = super().generate(input_name, framework, int_dtype, float_dtype)
if input_name == "attention_mask":
input = torch.ones(input.shape, dtype=input.dtype)
if input_name == "position_ids":
bs = input.shape[0]
            input = torch.arange(0, input.shape[1], dtype=input.dtype).repeat(bs, 1)
return input


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
self.multi_query_group_num = normalized_config.multi_query_group_num
self.head_dim = self.hidden_size // self.num_attention_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
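        # ChatGLM2 stores its KV cache sequence-first: (sequence_length, batch_size, multi_query_group_num, head_dim).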
past_key_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
past_value_shape = (
self.sequence_length,
self.batch_size,
self.multi_query_group_num,
self.head_dim,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]


@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
no_position_ids = False

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
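        # Follows the usual dummy-input generation flow; the past length is later read from dim 0 of the cached value tensor because the ChatGLM2 cache is sequence-first.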
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
if self.use_past_in_inputs and self.use_cache_branch is not False:
input_names.append("past_key_values")

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

# refer to https://github.com/huggingface/optimum/pull/764
cond1 = self.use_past_in_inputs
cond2 = self.PAD_ATTENTION_MASK_TO_PAST
cond3 = self.use_cache_branch is not False
cond4 = "attention_mask" in dummy_inputs
if cond1 and cond2 and cond3 and cond4:
# Obtain the past sequence length from the value instead of the key (Bloom).
past_length = dummy_inputs["past_key_values"][0][1].shape[0]
for k, v in dummy_inputs.items():
if k not in ["attention_mask", "past_key_values"]:
dummy_inputs[k] = v[:, -1:]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
dim=1,
dtype=dummy_inputs["attention_mask"].dtype,
)

return dummy_inputs

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
if not self.no_position_ids and self.task == "text-generation":
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
        Fills the `inputs_or_outputs` mapping with past_key_values dynamic axes considering the direction.
Args:
inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill.
direction (`str`):
either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
output mapping, this is important for axes naming.
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
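            # Sequence-first cache: dim 0 carries the (past) sequence length, dim 1 the batch size.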
inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
125 changes: 125 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -14,9 +14,12 @@

import logging as log
import types
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import is_tf_available

from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
from optimum.intel.utils.import_utils import (
@@ -27,6 +30,15 @@
)


if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel

from optimum.exporters.onnx.config import OnnxConfig

if is_tf_available():
from transformers.modeling_tf_utils import TFPreTrainedModel


def patch_model_with_bettertransformer(model):
# check that the model has not yet been pathced
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
@@ -107,3 +119,116 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward


def _chatglm_transformer_forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
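    # Replacement forward for the ChatGLM transformer; ChatGLMModelPatcher (below) installs it for the duration of the export.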
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)

if self.pre_seq_len is not None:
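        # A configured prefix prompt seeds the past key/values, and the attention mask is widened to cover the prefix tokens.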
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)

if full_attention_mask is None:
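        # When no mask is supplied, get_masks handles padded or continued inputs; with a past cache an additive causal mask is built directly (-inf above the diagonal, zero over cached positions).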
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
elif past_key_values is not None:
full_attention_mask = torch.ones(
batch_size,
seq_length,
seq_length,
device=input_ids.device,
dtype=torch.float,
) * float("-inf")
full_attention_mask.triu_(diagonal=1)
past_length = 0
if past_key_values:
past_length = past_key_values[0][0].shape[0]
if past_length:
full_attention_mask = torch.cat(
(
torch.zeros(batch_size, seq_length, past_length, device=input_ids.device),
full_attention_mask,
),
dim=-1,
)
full_attention_mask.unsqueeze_(1)

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)


class ChatGLMModelPatcher(DecoderModelPatcher):
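    # __enter__ rebinds transformer.forward to _chatglm_transformer_forward; __exit__ restores the original binding.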
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

self.original_chatglm_transformer_forward = model.transformer.forward

def __enter__(self):
super().__enter__()
self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.transformer.forward = self.original_chatglm_transformer_forward
