diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index d03478b0f3..0ce9268646 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -23,9 +23,8 @@
 from optimum.exporters import TasksManager
 from optimum.exporters.onnx.base import OnnxConfig
-from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
 from optimum.exporters.openvino.convert import export_from_model
-from optimum.intel.utils.import_utils import is_openvino_tokenizers_available, is_transformers_version
+from optimum.intel.utils.import_utils import is_openvino_tokenizers_available
 from optimum.utils.save_utils import maybe_load_preprocessors
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index f53e3bb07f..63e73f830e 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -32,7 +32,13 @@
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
 from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
-from optimum.exporters.utils import _get_submodels_and_export_configs, ENCODER_NAME, DECODER_NAME, DECODER_WITH_PAST_NAME, _get_submodels_for_export_encoder_decoder
+from optimum.exporters.utils import (
+    DECODER_NAME,
+    DECODER_WITH_PAST_NAME,
+    ENCODER_NAME,
+    _get_submodels_and_export_configs,
+    _get_submodels_for_export_encoder_decoder,
+)
 from optimum.intel.utils.import_utils import (
     _nncf_version,
     _optimum_intel_version,
@@ -576,13 +582,13 @@ def export_from_model(
     logging.disable(logging.INFO)
 
-    if (model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS)) and stateful and not custom_architecture:
+    if (
+        (model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS))
+        and stateful
+        and not custom_architecture
+    ):
         export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
-            model=model,
-            task=task,
-            preprocessors=preprocessors,
-            library_name=library_name,
-            _variant="default"
+            model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
         )
         stateful = [False, True]
     else:
@@ -773,9 +779,7 @@ def _get_encoder_decoder_stateful_models_for_export(
     )
     export_config.variant = _variant
-    all_variants = "\n".join(
-        [f" - {name}: {description}" for name, description in export_config.VARIANTS.items()]
-    )
+    all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
     logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")
 
     models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)
@@ -783,7 +787,9 @@ def _get_encoder_decoder_stateful_models_for_export(
     encoder_export_config = export_config.with_behavior("encoder")
     models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)
 
-    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True, stateful=True)
+    decoder_export_config_with_past = export_config.with_behavior(
+        "decoder", use_past=True, use_past_in_inputs=True, stateful=True
+    )
     decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
     models_for_export[DECODER_NAME] = (
         decoder_with_past_model,
@@ -791,5 +797,4 @@ def _get_encoder_decoder_stateful_models_for_export(
     )
 
     logger.info(models_for_export.keys())
-
     return None, models_for_export
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 509f35031d..2803c6966b 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -14,29 +14,30 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
-from optimum.exporters.onnx.base import ConfigBehavior
 from packaging import version
 from transformers import PretrainedConfig
 from transformers.utils import is_tf_available
 
-from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
+from optimum.exporters.onnx.base import ConfigBehavior
+from optimum.exporters.onnx.config import OnnxSeq2SeqConfigWithPast, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
     LlamaOnnxConfig,
+    M2M100OnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
+    T5OnnxConfig,
     UNetOnnxConfig,
     VaeDecoderOnnxConfig,
     VaeEncoderOnnxConfig,
-    WhisperOnnxConfig
+    WhisperOnnxConfig,
 )
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator
 from optimum.utils.input_generators import (
-    DummyInputGenerator,
     DummyPastKeyValuesGenerator,
     DummyTextInputGenerator,
     FalconDummyPastKeyValuesGenerator,
@@ -830,12 +831,35 @@ def patch_model_for_export(
         return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
-@register_in_tasks_manager("whisper", *["feature-extraction", "feature-extraction-with-past", "audio-classification", "automatic-speech-recognition", "automatic-speech-recognition-with-past",], library_name="transformers")
+@register_in_tasks_manager(
+    "whisper",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "audio-classification",
+        "automatic-speech-recognition",
+        "automatic-speech-recognition-with-past",
+    ],
+    library_name="transformers",
+)
 class WhisperOpenVINOConfig(WhisperOnnxConfig):
-    def __init__(self, config: PretrainedConfig, task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, legacy: bool = False, stateful: bool = False):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+        legacy: bool = False,
+        stateful: bool = False,
+    ):
         self.stateful = stateful
-        super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy)
-
+        super().__init__(
+            config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
+        )
 
     def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
         """
@@ -851,40 +875,335 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
         return super()._create_dummy_input_generator_classes(**kwargs)
 
     def with_behavior(
-        self,
-        behavior: Union[str, ConfigBehavior],
-        use_past: bool = False,
-        use_past_in_inputs: bool = False,
-        stateful: bool = False
-    ) -> "OnnxSeq2SeqConfigWithPast":
-        """
-        Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
-
-        Args:
-            behavior ([`ConfigBehavior`]):
-                The behavior to use for the new instance.
-            use_past (`bool`, defaults to `False`):
-                Whether or not the ONNX config to instantiate is for a model using KV cache.
-            use_past_in_inputs (`bool`, defaults to `False`):
-                Whether the KV cache is to be passed as an input to the ONNX.
-
-        Returns:
-            `OnnxSeq2SeqConfigWithPast`
-        """
-        if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
-            behavior = ConfigBehavior(behavior)
-
-        onnx_config = self.__class__(
-            self._config,
-            task=self.task,
-            int_dtype=self.int_dtype,
-            float_dtype=self.float_dtype,
-            use_past=use_past,
-            use_past_in_inputs=use_past_in_inputs,
-            behavior=behavior,
-            preprocessors=self._preprocessors,
-            legacy=self.legacy,
-            stateful=stateful
-        )
-        onnx_config.variant = self.variant
-        return onnx_config
+        self,
+        behavior: Union[str, ConfigBehavior],
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        stateful: bool = False,
+    ) -> "OnnxSeq2SeqConfigWithPast":
+        """
+        Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
+
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+            use_past (`bool`, defaults to `False`):
+                Whether or not the ONNX config to instantiate is for a model using KV cache.
+            use_past_in_inputs (`bool`, defaults to `False`):
+                Whether the KV cache is to be passed as an input to the ONNX.
+            stateful (`bool`, defaults to `False`):
+                Whether to produce a stateful model, handling the KV cache as an internal model state.
+
+        Returns:
+            `OnnxSeq2SeqConfigWithPast`
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
+            behavior = ConfigBehavior(behavior)
+
+        onnx_config = self.__class__(
+            self._config,
+            task=self.task,
+            int_dtype=self.int_dtype,
+            float_dtype=self.float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=self._preprocessors,
+            legacy=self.legacy,
+            stateful=stateful,
+        )
+        onnx_config.variant = self.variant
+        return onnx_config
+
+
+@register_in_tasks_manager(
+    "t5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class T5OpenVINOConfig(T5OnnxConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+        legacy: bool = False,
+        stateful: bool = False,
+    ):
+        self.stateful = stateful
+        super().__init__(
+            config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
+        )
+
+    def with_behavior(
+        self,
+        behavior: Union[str, ConfigBehavior],
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        stateful: bool = False,
+    ) -> "OnnxSeq2SeqConfigWithPast":
+        """
+        Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
+
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+            use_past (`bool`, defaults to `False`):
+                Whether or not the ONNX config to instantiate is for a model using KV cache.
+            use_past_in_inputs (`bool`, defaults to `False`):
+                Whether the KV cache is to be passed as an input to the ONNX.
+            stateful (`bool`, defaults to `False`):
+                Whether to produce a stateful model, handling the KV cache as an internal model state.
+
+        Returns:
+            `OnnxSeq2SeqConfigWithPast`
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
+            behavior = ConfigBehavior(behavior)
+
+        onnx_config = self.__class__(
+            self._config,
+            task=self.task,
+            int_dtype=self.int_dtype,
+            float_dtype=self.float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=self._preprocessors,
+            legacy=self.legacy,
+            stateful=stateful,
+        )
+        onnx_config.variant = self.variant
+        return onnx_config
+
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
+            self.task, self._normalized_config, **kwargs
+        )
+        dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1](
+            self.task,
+            self._normalized_config,
+            **kwargs,
+        )
+        dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2](
+            self.task,
+            self._normalized_config,
+            encoder_sequence_length=dummy_text_input_generator.sequence_length
+            if not self.stateful
+            else dummy_text_input_generator.sequence_length + 2,
+            **kwargs,
+        )
+        dummy_inputs_generators = [
+            dummy_text_input_generator,
+            dummy_decoder_text_input_generator,
+            dummy_seq2seq_past_key_values_generator,
+        ]
+
+        return dummy_inputs_generators
+
+
+@register_in_tasks_manager(
+    "mt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class MT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "longt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class LongT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "m2m-100",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class M2M100OpenVINOConfig(M2M100OnnxConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+        legacy: bool = False,
+        stateful: bool = False,
+    ):
+        self.stateful = stateful
+        super().__init__(
+            config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
+        )
+
+    def with_behavior(
+        self,
+        behavior: Union[str, ConfigBehavior],
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        stateful: bool = False,
+    ) -> "OnnxSeq2SeqConfigWithPast":
+        """
+        Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
+
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+            use_past (`bool`, defaults to `False`):
+                Whether or not the ONNX config to instantiate is for a model using KV cache.
+            use_past_in_inputs (`bool`, defaults to `False`):
+                Whether the KV cache is to be passed as an input to the ONNX.
+            stateful (`bool`, defaults to `False`):
+                Whether to produce a stateful model, handling the KV cache as an internal model state.
+
+        Returns:
+            `OnnxSeq2SeqConfigWithPast`
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
+            behavior = ConfigBehavior(behavior)
+
+        onnx_config = self.__class__(
+            self._config,
+            task=self.task,
+            int_dtype=self.int_dtype,
+            float_dtype=self.float_dtype,
+            use_past=use_past,
+            use_past_in_inputs=use_past_in_inputs,
+            behavior=behavior,
+            preprocessors=self._preprocessors,
+            legacy=self.legacy,
+            stateful=stateful,
+        )
+        onnx_config.variant = self.variant
+        return onnx_config
+
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
+            self.task, self._normalized_config, **kwargs
+        )
+        task = "feature-extraction" if self.task != "text-generation" else "text-generation"
+        dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task](
+            self.task, self._normalized_config, **kwargs
+        )
+        if self.task != "text-generation":
+            kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length
+            if self.stateful:
+                kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2
+
+        dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task](
+            self.task, self._normalized_config, **kwargs
+        )
+        dummy_inputs_generators = [
+            dummy_text_input_generator,
+            dummy_decoder_text_input_generator,
+            dummy_seq2seq_past_key_values_generator,
+        ]
+
+        return dummy_inputs_generators
+
+
+@register_in_tasks_manager(
+    "bart",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+        "text-classification",
+        "question-answering",
+    ],
+    library_name="transformers",
+)
+class BartOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "mbart",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+        "text-classification",
+        "question-answering",
+    ],
+    library_name="transformers",
+)
+class MBartOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "blenderbot",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class BlenderbotOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "blenderbot-small",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class BlenderbotSmallOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "marian",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class MarianOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "pegasus",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class PegasusOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py
index c2a2d07174..8003882176 100644
--- a/optimum/exporters/openvino/stateful.py
+++ b/optimum/exporters/openvino/stateful.py
@@ -182,11 +182,11 @@ def ensure_stateful_is_available(warn=True):
     return True
 
 
-def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = False):
+def ensure_export_task_support_stateful(task: str, is_encoder_decoder: bool = False):
     task = TasksManager.map_from_synonym(task)
     if not is_encoder_decoder:
         return task == "text-generation-with-past"
-    
+
     _ENCODER_DECODER_TASKS_WITH_PAST = (
         "automatic-speech-recognition",
         "document-question-answering",
@@ -197,7 +197,7 @@ def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = Fal
     is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST
 
     return is_stateful
-    
+
 
 def remove_parameters_by_names(model: ov.Model, names: list):
     parameters = [model.input(name).get_node() for name in names]
@@ -295,7 +295,14 @@ def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
 
 
 def patch_stateful_encoder_decoder(config, ov_model):
-    encoder_key_value_input_names = [key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())]
+    encoder_key_value_input_names = [
+        key.get_any_name()
+        for key in ov_model.inputs
+        if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
+    ]
     remove_parameters_by_names(ov_model, encoder_key_value_input_names)
     patch_stateful_decoder(config, ov_model)
-    insert_state_for_nodes(ov_model, find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]))
\ No newline at end of file
+    insert_state_for_nodes(
+        ov_model,
+        find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
+    )
diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index d144866f91..b3d368e7fc 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -37,8 +37,8 @@
 from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
-from ...exporters.openvino.stateful import model_has_state
 
+from ...exporters.openvino.stateful import model_has_state
 from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
 from .utils import _print_compiled_model_properties
@@ -342,7 +342,7 @@ def to(self, device: str):
             self._device = device.upper()
             self.encoder._device = self._device
             self.decoder._device = self._device
-            if self.use_cache and self.decoder_with_past_model is not None:
+            if self.use_cache and self.decoder_with_past is not None:
                 self.decoder_with_past._device = self._device
             self.clear_requests()
         else:
@@ -379,7 +379,7 @@ def forward(
                 encoder_hidden_states=encoder_outputs.last_hidden_state,
                 encoder_attention_mask=attention_mask,
                 decoder_attention_mask=decoder_attention_mask,
-                past_key_values = past_key_values
+                past_key_values=past_key_values,
             )
         else:
             decoder_outputs = self.decoder_with_past(
@@ -419,7 +419,7 @@ def get_encoder(self):
         return self.encoder
 
     def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        self.decoder._reorder_cache(past, beam_idx)  
+        self.decoder._reorder_cache(past, beam_idx)
 
     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -446,13 +446,13 @@ def half(self):
 
     def clear_requests(self):
         self.encoder.request = None
         self.decoder.request = None
-        if self.use_cache and self.decoder_with_past_model is not None:
+        if self.use_cache and self.decoder_with_past is not None:
             self.decoder_with_past.request = None
 
     def compile(self):
         self.encoder._compile()
         self.decoder._compile()
-        if self.use_cache and self.decoder_with_past_model is not None:
+        if self.use_cache and self.decoder_with_past is not None:
             self.decoder_with_past._compile()