Skip to content

Commit

Permalink
enable attention mask and fix accuracy issue for chatglm
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 20, 2023
1 parent fae7802 commit ce6cdaf
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 6 deletions.
12 changes: 12 additions & 0 deletions optimum/exporters/openvino/dummy_input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing import Optional, Tuple

import torch

from optimum.utils import (
DEFAULT_DUMMY_SHAPES,
DummyPastKeyValuesGenerator,
Expand All @@ -30,6 +32,16 @@ class ChatGLN2DummyTextInputGenerator(DummyTextInputGenerator):
"position_ids",
}

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
input = super().generate(input_name, framework, int_dtype, float_dtype)
if input_name == "attention_mask":
input = torch.ones((input.shape[0], input.shape[1] + 1), dtype=input.dtype)
# input[0] = 0
if input_name == "position_ids":
input = torch.range(0, input.shape[1] + 1, dtype=input.dtype).repeat(1, 1)
# input[0] = 0
return input


class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
Expand Down
1 change: 0 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
common_inputs.pop("attention_mask")
if not self.no_position_ids and self.task == "text-generation":
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

Expand Down
38 changes: 33 additions & 5 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import openvino
Expand All @@ -25,7 +25,7 @@
from openvino.runtime import Core, Tensor, Type
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput

from optimum.utils import NormalizedConfigManager

Expand Down Expand Up @@ -401,9 +401,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
Expand All @@ -413,6 +412,35 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"token_type_ids": None,
}

def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)

# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

model_kwargs["is_first_forward"] = False
return model_kwargs

def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
Expand Down
38 changes: 38 additions & 0 deletions optimum/intel/utils/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import types
from typing import Tuple

import torch
Expand Down Expand Up @@ -92,6 +93,40 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
return combined_attention_mask


@torch.jit.script_if_tracing
def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
if query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, is_causal=True
)
else:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer)
return context_layer


def _core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None:
context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
else:
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)

return context_layer


def _patch_chatglm_core_attention_forward(model: "PreTrainedModel"):
for block in model.transformer.encoder.layers:
block.self_attention.core_attention.forward = types.MethodType(
_core_attention_forward, block.self_attention.core_attention
)


def patch_decoder_attention_mask(model: "PreTrainedModel"):
"""
Apply patch on decoder with past model forward to resolve first inference based on model architecture
Expand All @@ -108,4 +143,7 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
elif model.config.model_type == "chatglm":
_patch_chatglm_core_attention_forward(model)

return model

0 comments on commit ce6cdaf

Please sign in to comment.