stateful seq2seq models #779

Closed · wants to merge 11 commits
5 changes: 2 additions & 3 deletions optimum/exporters/openvino/__main__.py
@@ -264,12 +264,11 @@ def main_export(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that does not support load model on cpu
if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
# if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
# loading_kwargs["attn_implementation"] = "eager"
# there are some difference between remote and in library representation of past key values for some models,
# for avoiding confusion we disable remote code for them
if (
Expand Down
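Note (editorial, not part of the diff): the `loading_kwargs` assembled above are forwarded to `from_pretrained` when the exporter loads the model, so the same workaround can be reproduced manually. A minimal sketch, assuming a transformers >= 4.36 checkpoint whose SDPA attention cannot be traced for export (the model ID is only an illustration):

    from transformers import AutoModelForSpeechSeq2Seq

    # "eager" attention avoids the SDPA graph patterns that some export paths cannot handle
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        "openai/whisper-tiny",
        attn_implementation="eager",
    )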
72 changes: 65 additions & 7 deletions optimum/exporters/openvino/convert.py
@@ -35,9 +35,10 @@
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
from optimum.exporters.utils import (
    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
)
from optimum.exporters.utils import (
    get_diffusion_models_for_export,
    DECODER_NAME,
    DECODER_WITH_PAST_NAME,
    ENCODER_NAME,
    _get_submodels_for_export_encoder_decoder,
)
from optimum.intel.utils.import_utils import (
    _diffusers_version,
@@ -534,6 +535,8 @@ def export_models(
f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export."
)

if not isinstance(stateful, (list, tuple)):
stateful = [stateful] * len(models_and_export_configs)
for i, model_name in enumerate(models_and_export_configs.keys()):
submodel, sub_export_config = models_and_export_configs[model_name]
output_name = output_names[i] if output_names is not None else Path(model_name + ".xml")
Expand All @@ -549,9 +552,9 @@ def export_models(
input_shapes=input_shapes,
model_kwargs=model_kwargs,
ov_config=ov_config,
stateful=stateful[i] if isinstance(stateful, (list, tuple)) else stateful,
patch_16bit_model=patch_16bit_model,
library_name=library_name,
stateful=stateful[i],
)
)

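Aside (not from the diff): the normalization above lets callers pass either a single flag for all submodels or a per-submodel list, which the encoder/decoder split below relies on (`[False, True]`). A standalone sketch of the same idea:

    def normalize_stateful(stateful, num_submodels):
        # A single bool is broadcast; an explicit list/tuple is taken as-is.
        if not isinstance(stateful, (list, tuple)):
            return [stateful] * num_submodels
        return list(stateful)

    assert normalize_stateful(True, 3) == [True, True, True]
    assert normalize_stateful([False, True], 2) == [False, True]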
@@ -613,9 +616,8 @@ def export_from_model(
        task = task + "-with-past"

    logger.info(f"Automatic task detection to: {task}.")

    stateful = stateful and (
        ensure_export_task_support_stateful(task)
        ensure_export_task_support_stateful(task, getattr(getattr(model, "config", {}), "is_encoder_decoder", False))
        or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
    )
    # TODO: support onnx_config.py in the model repo
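Context (editorial): a "stateful" export folds the past key/value cache into OpenVINO model state instead of exposing it as explicit inputs and outputs, and the gating above decides when that is safe. From the user side the switch is transparent; a hedged usage sketch with optimum-intel's public API (stateful export is applied when the task/model combination allows it):

    from optimum.intel import OVModelForSeq2SeqLM
    from transformers import AutoTokenizer

    # export=True converts the checkpoint through the code path shown in this diff
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True)
    inputs = tokenizer("translate English to German: hello", return_tensors="pt")
    print(tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True))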
@@ -647,13 +649,27 @@
            kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
        )

    logging.disable(logging.INFO)

    if library_name == "open_clip":
        custom_architecture = True
        custom_export_configs, fn_get_submodels = _get_open_clip_submodels_fn_and_export_configs(
            model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels
        )

    if library_name == "diffusers":
    elif (
        stateful
        and (
            task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False)
        )
        and not custom_architecture
    ):
        export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
            model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
        )
        stateful_submodels = [False, True]

    elif library_name == "diffusers":
        export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
        stateful_submodels = False
    else:
@@ -1193,3 +1209,45 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)

return models_for_export


def _get_encoder_decoder_stateful_models_for_export(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    task: str,
    _variant: str,
    library_name: str,
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    preprocessors: Optional[List[Any]] = None,
):
    export_config_constructor = TasksManager.get_exporter_config_constructor(
        model=model, exporter="openvino", task=task, library_name=library_name
    )
    export_config = export_config_constructor(
        model.config,
        int_dtype=int_dtype,
        float_dtype=float_dtype,
        preprocessors=preprocessors,
        legacy=False,
    )

    export_config.variant = _variant
    all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
    logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)

    encoder_export_config = export_config.with_behavior("encoder")
    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
    decoder_export_config_with_past.stateful = True
    decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
    models_for_export[DECODER_NAME] = (
        decoder_with_past_model,
        decoder_export_config_with_past,
    )

    return None, models_for_export
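Usage sketch (illustrative, not from the diff): the helper returns `(None, models_for_export)`, where the mapping holds the encoder under a plain encoder config and the decoder under a past-enabled config whose `stateful` flag is set, so `export_models` emits one encoder model and one stateful decoder model:

    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    _, submodels = _get_encoder_decoder_stateful_models_for_export(
        model=model,
        task="text2text-generation-with-past",
        _variant="default",
        library_name="transformers",
    )
    # Expected keys: ENCODER_NAME and DECODER_NAME;
    # submodels[DECODER_NAME][1].stateful is True.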
217 changes: 215 additions & 2 deletions optimum/exporters/openvino/model_configs.py
@@ -18,9 +18,13 @@

from packaging import version
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from optimum.exporters.onnx.base import ConfigBehavior
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.config import (
    TextDecoderOnnxConfig,
    TextDecoderWithPositionIdsOnnxConfig,
)
from optimum.exporters.onnx.model_configs import (
    CLIPOnnxConfig,
    CLIPTextOnnxConfig,

@@ -34,14 +38,20 @@
    IBertOnnxConfig,
    LlamaOnnxConfig,
    MistralOnnxConfig,
    M2M100OnnxConfig,
    MPTOnnxConfig,
    PhiOnnxConfig,
    T5OnnxConfig,
    UNetOnnxConfig,
    VisionOnnxConfig,
    VaeDecoderOnnxConfig,
    VaeEncoderOnnxConfig,
    WhisperOnnxConfig,
)
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator
from optimum.utils.input_generators import (
    DTYPE_MAPPER,
    DummyInputGenerator,

@@ -90,6 +100,7 @@
    QwenModelPatcher,
    RotaryEmbPatcher,
    UpdateCausalMaskModelPatcher,
    WhisperStatefulDecoderPatcher,
    XverseModelPatcher,
)

@@ -2204,3 +2215,205 @@ def patch_model_for_export(
        if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
            return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)


@register_in_tasks_manager(
    "whisper",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "audio-classification",
        "automatic-speech-recognition",
        "automatic-speech-recognition-with-past",
    ],
    library_name="transformers",
)
class WhisperOpenVINOConfig(WhisperOnnxConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
            return WhisperStatefulDecoderPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)
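A quick way to exercise this path end to end (illustrative; assumes the optimum-intel public API, which routes Whisper export through the config above):

    from optimum.intel import OVModelForSpeechSeq2Seq

    # export=True converts the checkpoint; the decoder is patched for stateful
    # export when the task supports it
    model = OVModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", export=True)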


@register_in_tasks_manager(
    "t5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class T5OpenVINOConfig(T5OnnxConfig):
    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
            self.task, self._normalized_config, **kwargs
        )
        dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1](
            self.task,
            self._normalized_config,
            **kwargs,
        )
        dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2](
            self.task,
            self._normalized_config,
            encoder_sequence_length=dummy_text_input_generator.sequence_length
            if not getattr(self, "stateful", False)
            else dummy_text_input_generator.sequence_length + 2,
            **kwargs,
        )
        dummy_inputs_generators = [
            dummy_text_input_generator,
            dummy_decoder_text_input_generator,
            dummy_seq2seq_past_key_values_generator,
        ]

        return dummy_inputs_generators
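Editorial note: the `+ 2` offset when `stateful` is set presumably keeps the dummy past key/value length from coinciding with the dummy encoder input length, so the traced decoder does not bake in an accidental shape equality. The resulting cross-attention cache shapes, in the layout transformers uses for seq2seq models (all numbers here are illustrative):

    import torch

    batch, heads, head_dim, enc_len = 2, 8, 64, 16
    stateful = True
    past_enc_len = enc_len + 2 if stateful else enc_len

    # per-layer cross-attention past key/value dummies
    past_key = torch.zeros(batch, heads, past_enc_len, head_dim)
    past_value = torch.zeros(batch, heads, past_enc_len, head_dim)
    print(past_key.shape)  # torch.Size([2, 8, 18, 64])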


@register_in_tasks_manager(
    "mt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class MT5OpenVINOConfig(T5OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "longt5",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class LongT5OpenVINOConfig(T5OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "m2m-100",
    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
    library_name="transformers",
)
class M2M100OpenVINOConfig(M2M100OnnxConfig):
    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
            self.task, self._normalized_config, **kwargs
        )
        task = "feature-extraction" if self.task != "text-generation" else "text-generation"
        dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task](
            self.task, self._normalized_config, **kwargs
        )
        if self.task != "text-generation":
            kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length
            if getattr(self, "stateful", False):
                kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2

        dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task](
            self.task, self._normalized_config, **kwargs
        )
        dummy_inputs_generators = [
            dummy_text_input_generator,
            dummy_decoder_text_input_generator,
            dummy_seq2seq_past_key_values_generator,
        ]

        return dummy_inputs_generators


@register_in_tasks_manager(
    "bart",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
        "text-classification",
        "question-answering",
    ],
    library_name="transformers",
)
class BartOpenVINOConfig(M2M100OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "mbart",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
        "text-classification",
        "question-answering",
    ],
    library_name="transformers",
)
class MBartOpenVINOConfig(M2M100OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "blenderbot",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
    ],
    library_name="transformers",
)
class BlenderbotOpenVINOConfig(M2M100OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "blenderbot-small",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
    ],
    library_name="transformers",
)
class BlenderbotSmallOpenVINOConfig(M2M100OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "marian",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
    ],
    library_name="transformers",
)
class MarianOpenVINOConfig(M2M100OpenVINOConfig):
    pass


@register_in_tasks_manager(
    "pegasus",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text2text-generation",
        "text2text-generation-with-past",
    ],
    library_name="transformers",
)
class PegasusOpenVINOConfig(M2M100OpenVINOConfig):
    pass
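Extending coverage to another BART-like architecture would follow the same pattern; a hypothetical sketch (model type and class name invented purely for illustration):

    @register_in_tasks_manager(
        "pegasus-x",  # hypothetical model type, shown only to illustrate the pattern
        *["text2text-generation", "text2text-generation-with-past"],
        library_name="transformers",
    )
    class PegasusXOpenVINOConfig(M2M100OpenVINOConfig):
        pass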