diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 1030684234..a6d5c60ef2 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -28,6 +28,7 @@ from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version from .convert import export_models +from .stateful import ensure_export_task_support_stateful if is_optimum_version(">=", "1.16.0"): @@ -282,7 +283,7 @@ class StoreAttr(object): synonyms_for_task = TasksManager.synonyms_for_task(task) synonyms_for_task.add(task) - if stateful and "text-generation-with-past" not in synonyms_for_task: + if stateful and not ensure_export_task_support_stateful(task): stateful = False preprocessors = maybe_load_preprocessors( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index ee9c24f88e..e0571b5338 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -144,6 +144,10 @@ def export( if "diffusers" in str(model.__class__) and not is_diffusers_available(): raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.") + if stateful: + # This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used + stateful = ensure_stateful_is_available() + if is_torch_available() and isinstance(model, nn.Module): return export_pytorch( model, @@ -497,9 +501,6 @@ def export_models( Returns: list of input_names and output_names from ONNX configuration """ - if stateful: - # This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used - stateful = ensure_stateful_is_available() outputs = [] if output_names is not None and len(output_names) != len(models_and_onnx_configs): diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 0d34263a80..4fe038fabb 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -20,6 +20,7 @@ import openvino as ov from openvino.runtime import opset13 +from optimum.exporters import TasksManager from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version from optimum.utils.normalized_config import NormalizedConfigManager @@ -197,6 +198,12 @@ def ensure_stateful_is_available(warn=True): return True +def ensure_export_task_support_stateful(task: str): + synonyms_for_task = TasksManager.synonyms_for_task(task) + synonyms_for_task.add(task) + return "text-generation-with-past" in synonyms_for_task + + def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): """ Apply stateful transformation to model to hiding key values inputs inside model. diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 3f5d270da9..63fac8df6d 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -38,6 +38,7 @@ from optimum.quantization_base import OptimumQuantizer from ...exporters.openvino import export, export_pytorch_via_onnx +from ...exporters.openvino.stateful import ensure_export_task_support_stateful from ..utils.constant import _TASK_ALIASES from .configuration import OVConfig from .modeling_base import OVBaseModel @@ -417,6 +418,8 @@ def _quantize_torchmodel( onnx_config = onnx_config_class( model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache ) + if model.config.use_cache: + task = "text-generation-with-past" else: onnx_config = onnx_config_class(model.config) @@ -425,7 +428,10 @@ def _quantize_torchmodel( export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET) opset = max(opset, MIN_ONNX_QDQ_OPSET) - _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset) + kwargs = {} + if not quantization_config.save_onnx_model: + kwargs = {"stateful": ensure_export_task_support_stateful(task)} + _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs) if is_onnx: # Load and save the compressed model model = core.read_model(onnx_path)