Skip to content

Commit

Permalink
ensure that task support stateful
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 10, 2024
1 parent 27c0f0a commit 18caf84
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 2 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
from .convert import export_models
from .stateful import ensure_export_task_support_stateful


if is_optimum_version(">=", "1.16.0"):
Expand Down Expand Up @@ -282,7 +283,7 @@ class StoreAttr(object):

synonyms_for_task = TasksManager.synonyms_for_task(task)
synonyms_for_task.add(task)
if stateful and "text-generation-with-past" not in synonyms_for_task:
if stateful and not ensure_export_task_support_stateful(task):
stateful = False

preprocessors = maybe_load_preprocessors(
Expand Down
7 changes: 4 additions & 3 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def export(
if "diffusers" in str(model.__class__) and not is_diffusers_available():
raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.")

if stateful:
# This will be checked anyway after the model conversion, but checking it earlier saves the user time if an unsuitable version is used
stateful = ensure_stateful_is_available()

if is_torch_available() and isinstance(model, nn.Module):
return export_pytorch(
model,
Expand Down Expand Up @@ -497,9 +501,6 @@ def export_models(
Returns:
list of input_names and output_names from ONNX configuration
"""
if stateful:
# This will be checked anyway after the model conversion, but checking it earlier saves the user time if an unsuitable version is used
stateful = ensure_stateful_is_available()
outputs = []

if output_names is not None and len(output_names) != len(models_and_onnx_configs):
Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import openvino as ov
from openvino.runtime import opset13
from optimum.exporters import TasksManager
from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
from optimum.utils.normalized_config import NormalizedConfigManager

Expand Down Expand Up @@ -197,6 +198,12 @@ def ensure_stateful_is_available(warn=True):
return True


def ensure_export_task_support_stateful(task: str):
    """
    Tell whether the given export task can be exported as a stateful model.

    The task qualifies when "text-generation-with-past" is either the task name
    itself or one of the names `TasksManager.synonyms_for_task` reports for it.
    """
    # Consider the task's own name together with every synonym known for it.
    known_names = {task, *TasksManager.synonyms_for_task(task)}
    return "text-generation-with-past" in known_names


def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
"""
Apply stateful transformation to the model to hide key-value inputs inside the model.
Expand Down
8 changes: 7 additions & 1 deletion optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from optimum.quantization_base import OptimumQuantizer

from ...exporters.openvino import export, export_pytorch_via_onnx
from ...exporters.openvino.stateful import ensure_export_task_support_stateful
from ..utils.constant import _TASK_ALIASES
from .configuration import OVConfig
from .modeling_base import OVBaseModel
Expand Down Expand Up @@ -417,6 +418,8 @@ def _quantize_torchmodel(
onnx_config = onnx_config_class(
model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
)
if model.config.use_cache:
task = "text-generation-with-past"
else:
onnx_config = onnx_config_class(model.config)

Expand All @@ -425,7 +428,10 @@ def _quantize_torchmodel(
export_fn = export if not quantization_config.save_onnx_model else export_pytorch_via_onnx
opset = min(onnx_config.DEFAULT_ONNX_OPSET, MAX_ONNX_OPSET)
opset = max(opset, MIN_ONNX_QDQ_OPSET)
_, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset)
kwargs = {}
if not quantization_config.save_onnx_model:
kwargs = {"stateful": ensure_export_task_support_stateful(task)}
_, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
if is_onnx:
# Load and save the compressed model
model = core.read_model(onnx_path)
Expand Down

0 comments on commit 18caf84

Please sign in to comment.