Add export support for new architectures (#720)
* update codegen config to support codegen2

* add support for DBRX

* add qwen2moe support

* fix test models

* add baichuan sdpa support

* apply review comments

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
eaidova and echarlaix authored May 21, 2024
1 parent 7a929e8 commit c69fe32
Showing 4 changed files with 396 additions and 1 deletion.
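
With these configs registered, the newly supported architectures can be exported to OpenVINO like any other decoder model. The snippet below is a minimal usage sketch, assuming optimum-intel is installed with the OpenVINO extra; the checkpoint name Qwen/Qwen1.5-MoE-A2.7B is only an example of a qwen2-moe model, and any codegen, dbrx, or qwen2-moe checkpoint would work the same way.

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

# Example checkpoint for the "qwen2-moe" architecture; substitute any supported model.
model_id = "Qwen/Qwen1.5-MoE-A2.7B"

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly, using the
# "text-generation-with-past" task registered by the configs in this commit.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

The same conversion can be done ahead of time with the CLI, e.g. optimum-cli export openvino --model Qwen/Qwen1.5-MoE-A2.7B qwen2_moe_ov.
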
47 changes: 47 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -20,6 +20,7 @@

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.model_configs import (
    CodeGenOnnxConfig,
    FalconOnnxConfig,
    GemmaOnnxConfig,
    LlamaOnnxConfig,
@@ -44,6 +45,8 @@
    AquilaModelPatcher,
    BaichuanModelPatcher,
    ChatGLMModelPatcher,
    CodeGenModelPatcher,
    DBRXModelPatcher,
    GemmaModelPatcher,
    InternLM2Patcher,
    InternLMModelPatcher,
@@ -112,6 +115,15 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14
@@ -738,3 +750,38 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"codegen",
*["feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past"],
library_name="transformers",
)
class CodeGenOpenVINOConfig(CodeGenOnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"dbrx",
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_attention_heads="n_heads",
        hidden_size="d_model",
        num_layers="n_layers",
        num_key_value_heads="attn_config.kv_n_heads",
        allow_new=True,
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return DBRXModelPatcher(self, model, model_kwargs=model_kwargs)
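
For a quick check that the new model types are visible to the exporter, the tasks manager can be queried directly. This is only a sketch: it assumes the module above is importable as optimum.exporters.openvino.model_configs (so the register_in_tasks_manager decorators have run) and that the installed optimum version exposes TasksManager.get_supported_tasks_for_model_type with a library_name argument.

import optimum.exporters.openvino.model_configs  # noqa: F401  (importing runs the decorators above)
from optimum.exporters.tasks import TasksManager

for model_type in ("codegen", "dbrx", "qwen2-moe"):
    # Returns a mapping of task name -> export config constructor for the given exporter.
    tasks = TasksManager.get_supported_tasks_for_model_type(
        model_type, exporter="openvino", library_name="transformers"
    )
    print(model_type, sorted(tasks))
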