Skip to content

Commit

Permalink
make input dynamic and enable sdpa
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Dec 18, 2024
1 parent 0249b17 commit 6b9dc88
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
6 changes: 5 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ def main_export(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
if (
is_transformers_version(">=", "4.36")
and is_transformers_version("<=", "4.45.0")
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
):
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that does not support load model on cpu
Expand Down
9 changes: 4 additions & 5 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
from openvino.tools.ovc import convert_model
from optimum.exporters import TasksManager
from optimum.exporters.utils import (
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
)
from optimum.exporters.utils import (
get_diffusion_models_for_export,
DECODER_NAME,
DECODER_WITH_PAST_NAME,
ENCODER_NAME,
_get_submodels_for_export_encoder_decoder,
get_diffusion_models_for_export,
)
from optimum.exporters.utils import (
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
)
from optimum.intel.utils.import_utils import (
_diffusers_version,
Expand All @@ -47,7 +47,6 @@
_torch_version,
_transformers_version,
compare_versions,
is_openvino_version,
is_openvino_tokenizers_version,
is_tokenizers_version,
is_transformers_version,
Expand Down
14 changes: 14 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,13 @@ def patch_model_for_export(
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
common_inputs = super().inputs
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
return common_inputs


@register_in_tasks_manager(
"t5",
Expand All @@ -2299,6 +2306,13 @@ def patch_model_for_export(
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
common_inputs = super().inputs
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
return common_inputs


@register_in_tasks_manager(
"mt5",
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from optimum.exporters.onnx.model_patcher import (
DecoderModelPatcher,
ModelPatcher,
override_arguments,
Seq2SeqModelPatcher,
override_arguments,
)
from optimum.intel.utils.import_utils import (
_openvino_version,
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_shape_of_ops(model: ov.Model):

def get_consumer_nodes(node):
consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
return set(input.get_node() for input in consumer_inputs)
return {input.get_node() for input in consumer_inputs}


def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
Expand Down

0 comments on commit 6b9dc88

Please sign in to comment.