Skip to content

Commit

Permalink
small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jun 26, 2024
1 parent 18ac8c0 commit ec24e13
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 183 deletions.
7 changes: 3 additions & 4 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,6 @@ def _get_encoder_decoder_stateful_models_for_export(
float_dtype: str = "fp32",
preprocessors: Optional[List[Any]] = None,
):
logger.info("HERE")
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="openvino", task=task, library_name=library_name
)
Expand All @@ -789,9 +788,9 @@ def _get_encoder_decoder_stateful_models_for_export(
encoder_export_config = export_config.with_behavior("encoder")
models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

decoder_export_config_with_past = export_config.with_behavior(
"decoder", use_past=True, use_past_in_inputs=True, stateful=True
)
decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

decoder_export_config_with_past.stateful = True
decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
models_for_export[DECODER_NAME] = (
decoder_with_past_model,
Expand Down
257 changes: 79 additions & 178 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from packaging import version
from transformers import PretrainedConfig
from transformers.utils import is_tf_available

from optimum.exporters.onnx.base import ConfigBehavior
from optimum.exporters.onnx.config import (
OnnxSeq2SeqConfigWithPast,
TextDecoderOnnxConfig,
TextDecoderWithPositionIdsOnnxConfig,
)
Expand All @@ -33,7 +30,9 @@
M2M100OnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
Pix2StructOnnxConfig,
T5OnnxConfig,
TrOCROnnxConfig,
UNetOnnxConfig,
VaeDecoderOnnxConfig,
VaeEncoderOnnxConfig,
Expand Down Expand Up @@ -847,140 +846,26 @@ def patch_model_for_export(
library_name="transformers",
)
class WhisperOpenVINOConfig(WhisperOnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
stateful: bool = False,
):
self.stateful = stateful
super().__init__(
config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
)

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
    """
    Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.

    When this config is marked as stateful and the caller did not provide an
    `encoder_sequence_length`, one is derived from `sequence_length` (taken from
    `kwargs`, falling back to `DEFAULT_DUMMY_SHAPES["sequence_length"]`) plus 2.
    The generators themselves are then created by the parent implementation.

    Args:
        **kwargs: Shape overrides forwarded to the parent implementation.

    Returns:
        The dummy input generators built by the parent class.
    """
    # `stateful` is assigned on the config object after construction and only for
    # stateful exports, so a missing attribute must mean "not stateful" rather than
    # raising AttributeError (same convention as the sibling OpenVINO configs).
    if getattr(self, "stateful", False) and "encoder_sequence_length" not in kwargs:
        sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"])
        # NOTE(review): the +2 offset presumably keeps the encoder length distinct from the
        # decoder length in the traced dummy inputs — confirm against the export code.
        kwargs["encoder_sequence_length"] = sequence_len + 2
    return super()._create_dummy_input_generator_classes(**kwargs)

def with_behavior(
self,
behavior: Union[str, ConfigBehavior],
use_past: bool = False,
use_past_in_inputs: bool = False,
stateful: bool = False,
) -> "OnnxSeq2SeqConfigWithPast":
"""
Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
Args:
behavior ([`ConfigBehavior`]):
The behavior to use for the new instance.
use_past (`bool`, defaults to `False`):
Whether or not the ONNX config to instantiate is for a model using KV cache.
use_past_in_inputs (`bool`, defaults to `False`):
Whether the KV cache is to be passed as an input to the ONNX.
Returns:
`OnnxSeq2SeqConfigWithPast`
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=self.legacy,
stateful=stateful,
)
onnx_config.variant = self.variant
return onnx_config


@register_in_tasks_manager(
"t5",
*["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
library_name="transformers",
)
class T5OpenVINOConfig(T5OnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
stateful: bool = False,
):
self.stateful = stateful
super().__init__(
config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
)

def with_behavior(
self,
behavior: Union[str, ConfigBehavior],
use_past: bool = False,
use_past_in_inputs: bool = False,
stateful: bool = False,
) -> "OnnxSeq2SeqConfigWithPast":
"""
Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
Args:
behavior ([`ConfigBehavior`]):
The behavior to use for the new instance.
use_past (`bool`, defaults to `False`):
Whether or not the ONNX config to instantiate is for a model using KV cache.
use_past_in_inputs (`bool`, defaults to `False`):
Whether the KV cache is to be passed as an input to the ONNX.
Returns:
`OnnxSeq2SeqConfigWithPast`
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=self.legacy,
stateful=stateful,
)
onnx_config.variant = self.variant
return onnx_config

def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
self.task, self._normalized_config, **kwargs
Expand All @@ -994,7 +879,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen
self.task,
self._normalized_config,
encoder_sequence_length=dummy_text_input_generator.sequence_length
if not self.stateful
if not getattr(self, "stateful", False)
else dummy_text_input_generator.sequence_length + 2,
**kwargs,
)
Expand Down Expand Up @@ -1031,63 +916,6 @@ class LongT5OpenVINOConfig(T5OpenVINOConfig):
library_name="transformers",
)
class M2M100OpenVINOConfig(M2M100OnnxConfig):
def __init__(
self,
config: PretrainedConfig,
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
preprocessors: Optional[List[Any]] = None,
legacy: bool = False,
stateful: bool = False,
):
self.stateful = stateful
super().__init__(
config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy
)

def with_behavior(
self,
behavior: Union[str, ConfigBehavior],
use_past: bool = False,
use_past_in_inputs: bool = False,
stateful: bool = False,
) -> "OnnxSeq2SeqConfigWithPast":
"""
Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value.
Args:
behavior ([`ConfigBehavior`]):
The behavior to use for the new instance.
use_past (`bool`, defaults to `False`):
Whether or not the ONNX config to instantiate is for a model using KV cache.
use_past_in_inputs (`bool`, defaults to `False`):
Whether the KV cache is to be passed as an input to the ONNX.
Returns:
`OnnxSeq2SeqConfigWithPast`
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=self.legacy,
stateful=stateful,
)
onnx_config.variant = self.variant
return onnx_config

def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
self.task, self._normalized_config, **kwargs
Expand All @@ -1098,7 +926,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen
)
if self.task != "text-generation":
kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length
if self.stateful:
if getattr(self, "stateful", False):
kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2

dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task](
Expand Down Expand Up @@ -1209,5 +1037,78 @@ class MarianOpenVINOConfig(M2M100OpenVINOConfig):
],
library_name="transformers",
)
class PegasusOpenVINOConfig(M2M100OnnxConfig):
class PegasusOpenVINOConfig(M2M100OpenVINOConfig):
pass


@register_in_tasks_manager(
    "pix2struct",
    *[
        "image-to-text",
        "image-to-text-with-past",
        "visual-question-answering",
        "visual-question-answering-with-past",
    ],
    library_name="transformers",
)
class Pix2StructOpenVINOConfig(Pix2StructOnnxConfig):
    """Export config for pix2struct that adjusts the encoder length of dummy inputs for stateful export."""

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        """
        Build one dummy input generator per entry of `self.DUMMY_INPUT_GENERATOR_CLASSES`.

        The first generator is created from the task and normalized config only. The
        remaining ones also receive the preprocessors and an `encoder_sequence_length`
        equal to the image processor's `max_patches`, increased by 2 when the config
        carries a truthy `stateful` attribute.

        Raises:
            ValueError: If the preprocessors are missing or there are not exactly two of them.
        """
        generator_classes = self.DUMMY_INPUT_GENERATOR_CLASSES
        generators = [generator_classes[0](self.task, self._normalized_config)]

        if self._preprocessors is None or len(self._preprocessors) != 2:
            raise ValueError(
                f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}"
            )

        max_patches = self._preprocessors[1].image_processor.max_patches
        is_stateful = getattr(self, "stateful", False)
        encoder_sequence_length = max_patches + 2 if is_stateful else max_patches

        # A hack for DummyPix2StructInputGenerator to gain access to the preprocessors.
        # TODO: we should probably pass preprocessors to all dummy input generators.
        kwargs["preprocessors"] = self._preprocessors
        generators.extend(
            generator_cls(
                self.task,
                self._normalized_config,
                encoder_sequence_length=encoder_sequence_length,
                **kwargs,
            )
            for generator_cls in generator_classes[1:]
        )
        return generators


@register_in_tasks_manager(
    "trocr",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "image-to-text",
        "image-to-text-with-past",
    ],
    library_name="transformers",
)
class TrOCROpenVINOConfig(TrOCROnnxConfig):
    """Export config for TrOCR that adjusts the encoder length of dummy past key values for stateful export."""

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        """
        Build the three dummy input generators listed in `self.DUMMY_INPUT_GENERATOR_CLASSES`.

        The first two generators are created from the task, the normalized config and
        `kwargs`. The third one additionally receives an `encoder_sequence_length` taken
        from the first generator's `sequence_length`, increased by 2 when the config
        carries a truthy `stateful` attribute.

        Returns:
            The generators in the same order as their classes are declared.
        """
        generator_classes = self.DUMMY_INPUT_GENERATOR_CLASSES
        shared_args = (self.task, self._normalized_config)

        text_generator = generator_classes[0](*shared_args, **kwargs)
        decoder_text_generator = generator_classes[1](*shared_args, **kwargs)

        encoder_sequence_length = text_generator.sequence_length
        if getattr(self, "stateful", False):
            encoder_sequence_length = encoder_sequence_length + 2

        past_key_values_generator = generator_classes[2](
            *shared_args,
            encoder_sequence_length=encoder_sequence_length,
            **kwargs,
        )

        return [text_generator, decoder_text_generator, past_key_values_generator]
3 changes: 2 additions & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,8 @@ def test_compare_to_transformers(self, model_arch):

self.assertIsInstance(ov_model.encoder, OVEncoder)
self.assertIsInstance(ov_model.decoder, OVDecoder)
self.assertIsInstance(ov_model.decoder_with_past, OVDecoder)
self.assertTrue(ov_model.decoder.stateful)
self.assertIsInstance(ov_model.decoder_with_past, None)
self.assertIsInstance(ov_model.config, PretrainedConfig)

transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
Expand Down

0 comments on commit ec24e13

Please sign in to comment.