Enable OpenVINO export of loaded model #557

Merged: 27 commits, Feb 15, 2024
Changes from 9 commits
1 change: 1 addition & 0 deletions .github/workflows/test_openvino.yml
@@ -32,6 +32,7 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install git+https://github.com/huggingface/optimum.git
pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
- name: Test with Pytest
run: |
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/__init__.py
@@ -1,5 +1,5 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx
from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
from .stateful import ensure_stateful_is_available, patch_stateful


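With this re-export, the new entry point is importable from the package root alongside the existing helpers; a minimal sketch:

# export_from_model is defined in convert.py (see the diff below).
from optimum.exporters.openvino import export_from_model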
297 changes: 108 additions & 189 deletions optimum/exporters/openvino/__main__.py

Large diffs are not rendered by default.
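Since the __main__.py diff is not rendered, the following is only a hedged sketch of what the refactor plausibly amounts to: the export logic moves into convert.py, and main_export reduces to "load the model, then delegate". The name main_export_sketch and the simplified signature are illustrative, not the PR's actual code.

from optimum.exporters import TasksManager
from optimum.exporters.openvino.convert import export_from_model

def main_export_sketch(model_name_or_path, output, task="auto", **kwargs):
    # Hypothetical simplification; the real main_export handles many more options.
    if task == "auto":
        task = TasksManager.infer_task_from_model(model_name_or_path)
    # Load the model in PyTorch, then hand the in-memory instance to the new helper.
    model = TasksManager.get_model_from_task(task, model_name_or_path, framework="pt")
    export_from_model(model, output=output, task=task, **kwargs)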

191 changes: 184 additions & 7 deletions optimum/exporters/openvino/convert.py
@@ -16,8 +16,9 @@
import gc
import inspect
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from transformers import T5Tokenizer, T5TokenizerFast
from transformers.utils import is_tf_available, is_torch_available
@@ -26,17 +27,20 @@
from openvino.runtime.exceptions import OVTypeError
from openvino.runtime.utils.types import get_element_type
from openvino.tools.ovc import convert_model
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
from optimum.utils import is_diffusers_available
from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available
from optimum.utils.save_utils import maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
from .model_patcher import patch_model_with_bettertransformer
from .stateful import ensure_stateful_is_available, patch_stateful
from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
from .utils import (
_MAX_UNCOMPRESSED_SIZE,
OV_XML_FILE_NAME,
clear_class_registry,
flattenize_inputs,
@@ -45,6 +49,16 @@
)


if is_optimum_version(">=", "1.16.99"):
from optimum.exporters.onnx.utils import _get_submodels_and_onnx_configs

else:
from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs


UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast)


logger = logging.getLogger(__name__)

if is_torch_available():
@@ -540,10 +554,173 @@ def export_models(
return outputs


UNSUPPORTED_TOKENIZER_CLASSES = (
T5Tokenizer,
T5TokenizerFast,
)
def export_from_model(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
output: Union[str, Path],
task: Optional[str] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = True,
opset: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List] = None,
device: str = "cpu",
**kwargs_shapes,
):
if (
compression_option is not None
and compression_option != "fp16"
and compression_option != "fp32"
and not is_nncf_available()
):
raise ImportError(
f"Compression of the weights to {compression_option} requires nncf, please install it with `pip install nncf`"
)

model_kwargs = model_kwargs or {}
library_name = TasksManager._infer_library_from_model(model)
TasksManager.standardize_model_attributes(model, library_name)

if hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE

if task is not None:
task = TasksManager.map_from_synonym(task)
else:
try:
task = TasksManager._infer_task_from_model_or_model_class(model=model)
except (ValueError, KeyError) as e:
raise RuntimeError(
f"The model task could not be automatically inferred in `onnx_export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if (
not custom_architecture
and library_name != "diffusers"
and task + "-with-past"
in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx", library_name=library_name)
):
# -with-past is the default.
task = task + "-with-past"

logger.info(f"Automatic task detection: {task}.")

stateful = stateful and ensure_export_task_support_stateful(task)

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
raise ValueError(
f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export."
)

if task.startswith("text-generation") and model.config.is_encoder_decoder:
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model,"
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TRANSFORMERS_TASKS_TO_MODEL_LOADERS`."
)
if library_name != "diffusers" and model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
f"{model_type} is not supported yet. Only {list(TasksManager._SUPPORTED_CLI_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)

output = Path(output)
if not output.exists():
output.mkdir(parents=True)

# Get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)

onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
model=model,
task=task,
monolith=False,
custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
library_name=library_name,
model_kwargs=model_kwargs,
_variant="default",
legacy=False,
)

if compression_option is None:
# TODO : sentence transformers compatibility
num_parameters = model.num_parameters() if library_name != "diffusers" else model.unet.num_parameters()
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
compression_option = "int8"
logger.info("The model weights will be quantized to int8.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
)

if library_name != "diffusers":
# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
generation_config.save_pretrained(output)

model_name_or_path = model.config._name_or_path
maybe_save_preprocessors(model_name_or_path, output)

files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()]

else:
# save the subcomponent configuration
for model_name in models_and_onnx_configs:
subcomponent = models_and_onnx_configs[model_name][0]
if hasattr(subcomponent, "save_config"):
subcomponent.save_config(output / model_name)
elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
subcomponent.config.save_pretrained(output / model_name)

files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs]

# Saving the additional components needed to perform inference.
model.scheduler.save_pretrained(output.joinpath("scheduler"))

feature_extractor = getattr(model, "feature_extractor", None)
if feature_extractor is not None:
feature_extractor.save_pretrained(output.joinpath("feature_extractor"))

tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
tokenizer.save_pretrained(output.joinpath("tokenizer"))

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

model.save_config(output)

export_models(
models_and_onnx_configs=models_and_onnx_configs,
output_dir=output,
output_names=files_subpaths,
input_shapes=input_shapes,
device=device,
compression_option=compression_option,
compression_ratio=compression_ratio,
stateful=stateful,
opset=opset,
model_kwargs=model_kwargs,
)


def export_tokenizer(
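Taken together, export_from_model is what enables the behavior in the PR title: a model that is already loaded (or just fine-tuned) in memory can be exported to OpenVINO IR directly, without reloading it from the Hub. A minimal usage sketch, with an illustrative checkpoint and output path:

from pathlib import Path

from transformers import AutoModelForSequenceClassification

from optimum.exporters.openvino import export_from_model

# Load the model first, then export the in-memory instance to OpenVINO IR.
model = AutoModelForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-bert")
export_from_model(model, output=Path("ov_model"), task="text-classification")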
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
@@ -30,6 +30,7 @@


OV_XML_FILE_NAME = "openvino_model.xml"
_MAX_UNCOMPRESSED_SIZE = 1e9


def is_torch_model(model: Union["PreTrainedModel", "ModelMixin"]):
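_MAX_UNCOMPRESSED_SIZE is the parameter-count threshold used by export_from_model above: when no compression_option is given, a model with at least 1e9 (one billion) parameters defaults to int8 weight compression, provided nncf is installed. A small illustration of the check, assuming an arbitrary checkpoint:

from transformers import AutoModelForCausalLM

_MAX_UNCOMPRESSED_SIZE = 1e9

model = AutoModelForCausalLM.from_pretrained("gpt2")  # ~124M parameters
if model.num_parameters() >= _MAX_UNCOMPRESSED_SIZE:
    print("weights would default to int8 compression (requires nncf)")
else:
    print("model would be exported without weight compression")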
14 changes: 14 additions & 0 deletions optimum/intel/ipex/__init__.py
@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from optimum.intel.ipex.modeling_base import (
IPEXModel,
IPEXModelForAudioClassification,
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from ..utils.import_utils import is_diffusers_available, is_nncf_available
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
14 changes: 14 additions & 0 deletions optimum/intel/openvino/modeling_timm.py
@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import OrderedDict
from typing import Dict, List, Optional, Union
14 changes: 14 additions & 0 deletions optimum/intel/openvino/training_args.py
@@ -1,3 +1,17 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from transformers import TrainingArguments
3 changes: 2 additions & 1 deletion setup.py
@@ -13,7 +13,7 @@

INSTALL_REQUIRE = [
"torch>=1.11",
"optimum>=1.14.0",
"optimum>=1.14.0", # TODO : 1.17.0
"transformers>=4.26.0",
"datasets>=1.4.0",
"sentencepiece",
@@ -33,6 +33,7 @@
"rjieba",
"timm",
"invisible-watermark>=0.2.0",
"auto-gptq",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]
15 changes: 10 additions & 5 deletions tests/openvino/test_modeling.py
@@ -482,8 +482,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"gpt_neo",
"gpt_neox",
"llama",
# "marian", # TODO : enable it back with openvino 2023.3.0
# "mistral",
"llama_gptq",
"marian",
"mistral",
"mpt",
"opt",
"pegasus",
@@ -494,6 +495,10 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]

if "gptq" in model_arch:
self.skipTest("Unsupported GPTQ model")

set_seed(SEED)
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_model.config, PretrainedConfig)
@@ -1031,7 +1036,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = OVModelForCTC.from_pretrained(MODEL_NAMES["t5"], export=True)

self.assertIn("Unrecognized configuration class", str(context.exception))
self.assertIn("only supports the tasks", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -1083,7 +1088,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = OVModelForAudioXVector.from_pretrained(MODEL_NAMES["t5"], export=True)

self.assertIn("Unrecognized configuration class", str(context.exception))
self.assertIn("only supports the tasks", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
@@ -1137,7 +1142,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = OVModelForAudioFrameClassification.from_pretrained(MODEL_NAMES["t5"], export=True)

self.assertIn("Unrecognized configuration class", str(context.exception))
self.assertIn("only supports the tasks", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -49,6 +49,7 @@
"levit": "hf-internal-testing/tiny-random-LevitModel",
"longt5": "hf-internal-testing/tiny-random-longt5",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"opt125m": "facebook/opt-125m",