
Commit

add openvino export configs

eaidova committed Feb 19, 2024
1 parent 8f7d016 commit b08610f
Showing 6 changed files with 107 additions and 24 deletions.
5 changes: 5 additions & 0 deletions optimum/exporters/openvino/__init__.py
@@ -1,6 +1,11 @@
 from .__main__ import main_export
+from .base import init_model_configs
 from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
+from .model_configs import *
 from .stateful import ensure_stateful_is_available, patch_stateful
 
 
+init_model_configs()
+
+
 __all__ = ["main_export", "export", "export_models"]
27 changes: 27 additions & 0 deletions optimum/exporters/openvino/base.py
@@ -0,0 +1,27 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from optimum.exporters.tasks import TasksManager


def init_model_configs():
    supported_models = TasksManager._SUPPORTED_MODEL_TYPE
    for model, export_configs in supported_models.items():
        if "onnx" not in export_configs:
            continue
        TasksManager._SUPPORTED_MODEL_TYPE[model]["openvino"] = deepcopy(
            TasksManager._SUPPORTED_MODEL_TYPE[model]["onnx"]
        )
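Together with the __init__.py change above, this means that merely importing optimum.exporters.openvino runs init_model_configs() and mirrors every registered ONNX export config under an "openvino" exporter key. A minimal sketch of querying the mirrored registry, assuming an optimum version whose get_exporter_config_constructor accepts these keywords (the llama model type and the task are illustrative):

from optimum.exporters.tasks import TasksManager

import optimum.exporters.openvino  # noqa: F401 -- importing runs init_model_configs()

# Any model type that had an "onnx" entry now also has an "openvino" entry,
# so exporter-specific lookups resolve against the mirrored configs.
config_constructor = TasksManager.get_exporter_config_constructor(
    exporter="openvino",
    model_type="llama",      # illustrative: any type with an ONNX config works
    task="text-generation",
    library_name="transformers",
)
print(config_constructor)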
37 changes: 16 additions & 21 deletions optimum/exporters/openvino/convert.py
@@ -32,10 +32,11 @@
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
+from optimum.exporters.utils import _get_submodels_and_export_configs
 from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available
 from optimum.utils.save_utils import maybe_save_preprocessors
 
-from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
+from ...intel.utils.import_utils import is_nncf_available
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
 from .utils import (
@@ -48,13 +49,6 @@
 )
 
 
-if is_optimum_version(">=", "1.16.99"):
-    from optimum.exporters.onnx.utils import _get_submodels_and_onnx_configs
-
-else:
-    from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs
-
-
 UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast)
 
 
@@ -458,7 +452,7 @@ def ts_patched_forward(*args, **kwargs):
 
 
 def export_models(
-    models_and_onnx_configs: Dict[
+    models_and_export_configs: Dict[
         str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
     ],
     output_dir: Path,
@@ -475,7 +469,7 @@
     Export the models to OpenVINO IR format
 
     Args:
-        models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
+        models_and_export_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
         output_dir (Path): output directory for saving models
         opset (Optional[int], optional, Default to None): ONNX export opset
         output_names (Optional[List[str]], optional, Defaults to None): model output names
@@ -504,20 +498,20 @@ def export_models(
     # TODO : modify compression_option to quantization_config
     outputs = []
 
-    if output_names is not None and len(output_names) != len(models_and_onnx_configs):
+    if output_names is not None and len(output_names) != len(models_and_export_configs):
         raise ValueError(
-            f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export."
+            f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export."
         )
 
-    for i, model_name in enumerate(models_and_onnx_configs.keys()):
-        submodel, sub_onnx_config = models_and_onnx_configs[model_name]
+    for i, model_name in enumerate(models_and_export_configs.keys()):
+        submodel, sub_export_config = models_and_export_configs[model_name]
         output_name = output_names[i] if output_names is not None else Path(model_name + ".xml")
         output_path = output_dir / output_name
         output_path.parent.mkdir(parents=True, exist_ok=True)
         outputs.append(
             export(
                 model=submodel,
-                config=sub_onnx_config,
+                config=sub_export_config,
                 output=output_path,
                 opset=opset,
                 device=device,
@@ -621,7 +615,7 @@ def export_from_model(
             kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
         )
 
-    onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
+    export_config, models_and_export_configs = _get_submodels_and_export_configs(
         model=model,
         task=task,
         monolith=False,
@@ -633,6 +627,7 @@
         model_kwargs=model_kwargs,
         _variant="default",
         legacy=False,
+        exporter="openvino",
     )
 
     if compression_option is None:
@@ -661,18 +656,18 @@
         model_name_or_path = model.config._name_or_path
         maybe_save_preprocessors(model_name_or_path, output)
 
-        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()]
+        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
 
     else:
         # save the subcomponent configuration
-        for model_name in models_and_onnx_configs:
-            subcomponent = models_and_onnx_configs[model_name][0]
+        for model_name in models_and_export_configs:
+            subcomponent = models_and_export_configs[model_name][0]
             if hasattr(subcomponent, "save_config"):
                 subcomponent.save_config(output / model_name)
             elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
                 subcomponent.config.save_pretrained(output / model_name)
 
-        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs]
+        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_export_configs]
 
         # Saving the additional components needed to perform inference.
         model.scheduler.save_pretrained(output.joinpath("scheduler"))
@@ -692,7 +687,7 @@ def export_from_model(
         model.save_config(output)
 
     export_models(
-        models_and_onnx_configs=models_and_onnx_configs,
+        models_and_export_configs=models_and_export_configs,
        output_dir=output,
        output_names=files_subpaths,
        input_shapes=input_shapes,
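For orientation, export_from_model now builds the submodel dictionary via _get_submodels_and_export_configs(..., exporter="openvino") and forwards it to export_models(models_and_export_configs=...). A hedged sketch of the call shape under these changes (the checkpoint is illustrative and the exact keyword set may differ between versions):

from pathlib import Path

from transformers import AutoModelForCausalLM

from optimum.exporters.openvino import export_from_model

# Illustrative tiny checkpoint; any causal LM with a registered config works.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Internally this resolves export configs from the "openvino" registry and
# calls export_models(models_and_export_configs=...) for each submodel.
export_from_model(model, output=Path("ov_model"), task="text-generation")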
56 changes: 56 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -0,0 +1,56 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"])
class BaichaunOpenVINOConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
)


@register_in_tasks_manager("jais", *["text-generation", "text-generation-with-past"])
class JaisOpenVINOConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_layers="n_layer", num_attention_heads="n_head", hidden_size="n_embd"
)


@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"])
class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"])
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
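TasksManager.create_register("openvino", overwrite_existing=True) returns a decorator that binds a config class to a (model type, task) pair in the OpenVINO registry, overriding any entry mirrored from ONNX. A sketch of registering support for another decoder-only architecture with the same pattern (the my-decoder model type and its attribute mapping are hypothetical):

from optimum.exporters.onnx.config import TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils.normalized_config import NormalizedTextConfig

register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("my-decoder", *["text-generation", "text-generation-with-past"])
class MyDecoderOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    # Hypothetical model type: map its config attribute names onto the
    # normalized names that the dummy input generators expect.
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size"
    )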
2 changes: 1 addition & 1 deletion optimum/intel/openvino/quantization.py
@@ -335,7 +335,7 @@ def _quantize_torchmodel(
 
         model_type = self.model.config.model_type.replace("_", "-")
         onnx_config_class = TasksManager.get_exporter_config_constructor(
-            exporter="onnx",
+            exporter="openvino",
             model=self.model,
             task=self.task,
             model_type=model_type,
4 changes: 2 additions & 2 deletions setup.py
@@ -13,7 +13,7 @@
 
 INSTALL_REQUIRE = [
     "torch>=1.11",
-    "optimum>=1.17.0",
+    "optimum @ git+https://github.com/eaidova/optimum.git@ea/move_model_preparation#egg=optimum",
     "transformers>=4.26.0",
     "datasets>=1.4.0",
     "sentencepiece",
@@ -50,7 +50,7 @@
         "onnx",
         "onnxruntime",
         "transformers>=4.36.0",
-        "optimum>=1.16.1",
+        "optimum @ git+https://github.com/eaidova/optimum.git@ea/move_model_preparation#egg=optimum"
     ],
     "openvino-tokenizers": ["openvino-tokenizers[transformers]"],
     "nncf": ["nncf>=2.8.1"],
