gemma
eaidova committed Mar 1, 2024
1 parent 02889ac · commit 1b29066
Showing 4 changed files with 85 additions and 19 deletions.
37 changes: 28 additions & 9 deletions optimum/exporters/openvino/model_configs.py
@@ -19,7 +19,8 @@
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
-from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, MixtralModelPatcher
+from optimum.exporters.onnx.model_configs import GemmaOnnxConfig
+from optimum.exporters.openvino.model_patcher import ChatGLMModelPatcher, GemmaModelPatcher, MixtralModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
@@ -65,23 +66,23 @@ def init_model_configs():
register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
)


@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"], library_name="transformers")
class JaisOpenVINOConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
)


@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

@@ -90,7 +91,7 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

Expand All @@ -99,7 +100,7 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

@@ -128,7 +129,7 @@ def __init__(
random_sequence_length_range=random_sequence_length_range,
)
self.multi_query_group_num = normalized_config.multi_query_group_num
-self.head_dim = self.hidden_size // self.num_attention_heads
+self.head_dim = normalized_config.kv_channels

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
past_key_shape = (
@@ -152,7 +153,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
]


@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
@@ -232,7 +233,7 @@ def patch_model_for_export(
return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"])
@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
@@ -249,3 +250,21 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MixtralModelPatcher(self, model, model_kwargs=model_kwargs)


+@register_in_tasks_manager(
+    "gemma",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GemmaOpenVINOConfig(GemmaOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
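
With this registration in place, the exporter can resolve an OpenVINO config for Gemma checkpoints by model type. A minimal sketch of the lookup (not part of the commit, assuming the TasksManager API of optimum at the time):

from optimum.exporters.tasks import TasksManager

# Resolve the config registered above; "openvino" and library_name mirror the
# register_in_tasks_manager call in this file.
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    task="text-generation-with-past",
    model_type="gemma",
    library_name="transformers",
)
print(config_constructor)  # a partial binding of GemmaOpenVINOConfig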
50 changes: 49 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -214,6 +214,34 @@ def _chatglm_transformer_forward(
)


+@torch.jit.script_if_tracing
+def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
+    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
+    if query_layer.shape[2] == key_layer.shape[2]:
+        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
+        mask.masked_fill_(tmp_mask, float("-inf"))
+
+    context_layer = torch.nn.functional.scaled_dot_product_attention(
+        query_layer, key_layer, value_layer, attn_mask=mask
+    )
+    return context_layer
+
+
+def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+    if attention_mask is None:
+        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
+    else:
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attention_mask
+        )
+    context_layer = context_layer.permute(2, 0, 1, 3)
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+
+    return context_layer


class ChatGLMModelPatcher(DecoderModelPatcher):
def __init__(
self,
@@ -228,7 +256,27 @@ def __init__(
def __enter__(self):
super().__enter__()
self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
+            block.self_attention.core_attention.forward = types.MethodType(
+                _chatglm2_core_attention_forward, block.self_attention.core_attention
+            )

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.transformer.forward = self.original_chatglm_transformer_forward
+        for block in self._model.transformer.encoder.layers:
+            block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward


+class GemmaModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+
+        # init inv_freq for torchscript tracing
+        for layer in self._model.model.layers:
+            if layer.self_attn.rotary_emb.inv_freq is None:
+                rotary_emb = layer.self_attn.rotary_emb
+                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
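
GemmaModelPatcher is needed because Gemma's rotary embedding leaves inv_freq unset until the first forward pass, which TorchScript tracing cannot handle; the patcher materializes it eagerly. A standalone sketch of the same inverse-frequency formula, with base and dim values assumed for illustration (the patcher reads them from each layer's rotary_emb):

import torch

base, dim = 10000.0, 256  # assumed values for illustration
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
assert inv_freq.shape == (dim // 2,)  # one inverse frequency per channel pair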
16 changes: 7 additions & 9 deletions tests/openvino/test_modeling.py
@@ -53,7 +53,6 @@
from transformers.onnx.utils import get_preprocessor
from utils_tests import MODEL_NAMES

-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.intel import (
OVModelForAudioClassification,
OVModelForAudioFrameClassification,
@@ -481,6 +480,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"chatglm",
"codegen",
# "data2vec-text", # TODO : enable when enabled in exporters
"gemma",
"gpt2",
"gpt_neo",
"gpt_neox",
@@ -502,10 +502,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
not_stateful = ["gpt_bigcode", "llama"]
not_stateful = ["gpt_bigcode"]
if is_openvino_version("<", "2024.0"):
not_stateful.append("mixtral")

if is_openvino_version("<", "2024.1"):
not_stateful.extend(["llama", "gemma"])

if "gptq" in model_arch:
self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")

@@ -528,11 +531,7 @@ def test_compare_to_transformers(self, model_arch):
tokens = tokenizer(
"This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
)
-position_ids = None
-if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
-    input_shape = tokens["input_ids"].shape
-    position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
-ov_outputs = ov_model(**tokens, position_ids=position_ids)
+ov_outputs = ov_model(**tokens)

self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)
@@ -542,12 +541,11 @@ def test_compare_to_transformers(self, model_arch):
self.assertEqual(ov_model.stateful, is_stateful)
if is_stateful:
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

with torch.no_grad():
transformers_outputs = transformers_model(**tokens)

# Compare tensor outputs
-self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4))
del transformers_model
del ov_model
gc.collect()
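
An end-to-end sketch (not part of the commit) of the path this test exercises, using the tiny Gemma checkpoint registered below in utils_tests.py; with OpenVINO >= 2024.1 the exported decoder is expected to be stateful, keeping its KV-cache inside the compiled model:

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "fxmarty/tiny-random-GemmaForCausalLM"
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # export on the fly
tokenizer = AutoTokenizer.from_pretrained(model_id)

outputs = model(**tokenizer("This is a sample", return_tensors="pt"))
print(model.stateful, outputs.logits.shape)  # stateful models return empty past_key_values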
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -39,6 +39,7 @@
"convnext": "hf-internal-testing/tiny-random-convnext",
"distilbert": "hf-internal-testing/tiny-random-distilbert",
"electra": "hf-internal-testing/tiny-random-electra",
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
