diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index dba4628d79..639286988b 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -264,12 +264,11 @@ def main_export( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." ) - if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: - loading_kwargs["attn_implementation"] = "eager" - # some models force flash_attn attention by default that does not support load model on cpu if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES: loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type] + # if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: + # loading_kwargs["attn_implementation"] = "eager" # there are some difference between remote and in library representation of past key values for some models, # for avoiding confusion we disable remote code for them if ( diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index fdcfbecf53..38bebc4f8b 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -35,9 +35,10 @@ from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx from optimum.exporters.utils import ( _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, -) -from optimum.exporters.utils import ( - get_diffusion_models_for_export, + DECODER_NAME, + DECODER_WITH_PAST_NAME, + ENCODER_NAME, + _get_submodels_for_export_encoder_decoder, ) from optimum.intel.utils.import_utils import ( _diffusers_version, @@ -534,6 +535,8 @@ def export_models( f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export." 
) + if not isinstance(stateful, (list, tuple)): + stateful = [stateful] * len(models_and_export_configs) for i, model_name in enumerate(models_and_export_configs.keys()): submodel, sub_export_config = models_and_export_configs[model_name] output_name = output_names[i] if output_names is not None else Path(model_name + ".xml") @@ -549,9 +552,9 @@ def export_models( input_shapes=input_shapes, model_kwargs=model_kwargs, ov_config=ov_config, - stateful=stateful[i] if isinstance(stateful, (list, tuple)) else stateful, patch_16bit_model=patch_16bit_model, library_name=library_name, + stateful=stateful[i], ) ) @@ -613,9 +616,8 @@ def export_from_model( task = task + "-with-past" logger.info(f"Automatic task detection to: {task}.") - stateful = stateful and ( - ensure_export_task_support_stateful(task) + ensure_export_task_support_stateful(task, getattr(getattr(model, "config", {}), "is_encoder_decoder", False)) or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", "")) ) # TODO: support onnx_config.py in the model repo @@ -647,13 +649,27 @@ def export_from_model( kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] ) + logging.disable(logging.INFO) + if library_name == "open_clip": custom_architecture = True custom_export_configs, fn_get_submodels = _get_open_clip_submodels_fn_and_export_configs( model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels ) - if library_name == "diffusers": + elif ( + stateful + and ( + task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False) + ) + and not custom_architecture + ): + export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( + model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default" + ) + stateful_submodels = [False, True] + + elif library_name == "diffusers": export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino") stateful_submodels = False else: @@ -1193,3 +1209,45 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype): models_for_export["text_encoder_2"] = (text_encoder_2, export_config) return models_for_export + + +def _get_encoder_decoder_stateful_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + task: str, + _variant: str, + library_name: str, + int_dtype: str = "int64", + float_dtype: str = "fp32", + preprocessors: Optional[List[Any]] = None, +): + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="openvino", task=task, library_name=library_name + ) + export_config = export_config_constructor( + model.config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=False, + ) + + export_config.variant = _variant + all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()]) + logger.info(f"Using the export variant {export_config.variant}. 
Available variants are:\n{all_variants}")
+
+    models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)
+
+    encoder_export_config = export_config.with_behavior("encoder")
+    models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)
+
+    decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
+
+    decoder_export_config_with_past.stateful = True
+    decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
+    models_for_export[DECODER_NAME] = (
+        decoder_with_past_model,
+        decoder_export_config_with_past,
+    )
+    logger.debug(f"Stateful encoder-decoder submodels to export: {list(models_for_export.keys())}")
+
+    return None, models_for_export
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index b8310882ba..5427f92839 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -18,9 +18,13 @@
 from packaging import version
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
+from optimum.exporters.onnx.base import ConfigBehavior
 from transformers.utils import is_tf_available

-from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
+from optimum.exporters.onnx.config import (
+    TextDecoderOnnxConfig,
+    TextDecoderWithPositionIdsOnnxConfig,
+)
 from optimum.exporters.onnx.model_configs import (
     CLIPOnnxConfig,
     CLIPTextOnnxConfig,
@@ -34,14 +38,20 @@
     IBertOnnxConfig,
     LlamaOnnxConfig,
     MistralOnnxConfig,
+    M2M100OnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
+    T5OnnxConfig,
     UNetOnnxConfig,
     VisionOnnxConfig,
+    VaeDecoderOnnxConfig,
+    VaeEncoderOnnxConfig,
+    WhisperOnnxConfig,
 )
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager
-from optimum.utils import DEFAULT_DUMMY_SHAPES
+from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator
 from optimum.utils.input_generators import (
     DTYPE_MAPPER,
     DummyInputGenerator,
@@ -90,6 +100,7 @@
     QwenModelPatcher,
     RotaryEmbPatcher,
     UpdateCausalMaskModelPatcher,
+    WhisperStatefulDecoderPatcher,
     XverseModelPatcher,
 )
@@ -2204,3 +2215,205 @@ def patch_model_for_export(
         if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
             return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
         return super().patch_model_for_export(model, model_kwargs)
+
+
+@register_in_tasks_manager(
+    "whisper",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "audio-classification",
+        "automatic-speech-recognition",
+        "automatic-speech-recognition-with-past",
+    ],
+    library_name="transformers",
+)
+class WhisperOpenVINOConfig(WhisperOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return WhisperStatefulDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+
+@register_in_tasks_manager(
+    "t5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class T5OpenVINOConfig(T5OnnxConfig):
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        dummy_text_input_generator = 
self.DUMMY_INPUT_GENERATOR_CLASSES[0]( + self.task, self._normalized_config, **kwargs + ) + dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1]( + self.task, + self._normalized_config, + **kwargs, + ) + dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2]( + self.task, + self._normalized_config, + encoder_sequence_length=dummy_text_input_generator.sequence_length + if not getattr(self, "stateful", False) + else dummy_text_input_generator.sequence_length + 2, + **kwargs, + ) + dummy_inputs_generators = [ + dummy_text_input_generator, + dummy_decoder_text_input_generator, + dummy_seq2seq_past_key_values_generator, + ] + + return dummy_inputs_generators + + +@register_in_tasks_manager( + "mt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class MT5OpenVINOConfig(T5OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "longt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class LongT5OpenVINOConfig(T5OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "m2m-100", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class M2M100OpenVINOConfig(M2M100OnnxConfig): + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( + self.task, self._normalized_config, **kwargs + ) + task = "feature-extraction" if self.task != "text-generation" else "text-generation" + dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task]( + self.task, self._normalized_config, **kwargs + ) + if self.task != "text-generation": + kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length + if getattr(self, "stateful", False): + kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2 + + dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task]( + self.task, self._normalized_config, **kwargs + ) + dummy_inputs_generators = [ + dummy_text_input_generator, + dummy_decoder_text_input_generator, + dummy_seq2seq_past_key_values_generator, + ] + + return dummy_inputs_generators + + +@register_in_tasks_manager( + "bart", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", + "question-answering", + ], + library_name="transformers", +) +class BartOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "mbart", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", + "question-answering", + ], + library_name="transformers", +) +class MBartOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "blenderbot", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + ], + library_name="transformers", +) +class BlenderbotOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "blenderbot-small", + *[ 
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class BlenderbotSmallOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "marian",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class MarianOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "pegasus",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text2text-generation",
+        "text2text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class PegasusOpenVINOConfig(M2M100OpenVINOConfig):
+    pass
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 58659e637b..88b417a3f1 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -24,7 +24,12 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available

-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
+from optimum.exporters.onnx.model_patcher import (
+    DecoderModelPatcher,
+    ModelPatcher,
+    override_arguments,
+    Seq2SeqModelPatcher,
+)
 from optimum.intel.utils.import_utils import (
     _openvino_version,
     _torch_version,
@@ -3237,3 +3242,51 @@ def __init__(
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+class WhisperStatefulDecoderPatcher(Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        @functools.wraps(model.__orig_forward)
+        def patched_forward(*args, **kwargs):
+            from transformers.cache_utils import EncoderDecoderCache
+
+            signature = inspect.signature(self.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+            return_legacy_cache = False
+            pkv_in_args = False
+            legacy_pkv = None
+            if "past_key_values" in kwargs:
+                legacy_pkv = kwargs.pop("past_key_values", None)
+            sign_names = list(signature.parameters.keys())
+            pkv_argument_index = sign_names.index("past_key_values")
+            if legacy_pkv is None and len(args) > pkv_argument_index:
+                legacy_pkv = args[pkv_argument_index]
+                pkv_in_args = True
+            if legacy_pkv is not None:
+                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
+                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
+                return_legacy_cache = True
+                if not pkv_in_args:
+                    kwargs["past_key_values"] = pkv
+                else:
+                    args[pkv_argument_index] = pkv
+
+            outputs = model.__orig_forward(*args, **kwargs)
+            if return_legacy_cache:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
+            return outputs
+
+        model.forward = patched_forward
+
+        super().__init__(config, model, model_kwargs)
diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py
index 4b4374ab51..cf58a3693e 100644
--- a/optimum/exporters/openvino/stateful.py
+++ b/optimum/exporters/openvino/stateful.py
@@ -191,9 
+191,76 @@ def ensure_stateful_is_available(warn=True): return True -def ensure_export_task_support_stateful(task: str): +def ensure_export_task_support_stateful(task: str, is_encoder_decoder: bool = False): task = TasksManager.map_from_synonym(task) - return task in ["text-generation-with-past"] + if not is_encoder_decoder: + return task == "text-generation-with-past" + + _ENCODER_DECODER_TASKS_WITH_PAST = ( + "automatic-speech-recognition", + "text2text-generation", + ) + + is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + return is_stateful + + +def remove_parameters_by_names(model: ov.Model, names: list): + parameters = [model.input(name).get_node() for name in names] + for p in parameters: + model.remove_parameter(p) + + +def get_input_nodes(node): + return [input.get_node() for input in node.input_values()] + + +def find_dependent_nodes(model: ov.Model, sources: list): + # Finds all nodes in `model` that are directly or indirectly dependent on at least one node from the list of nodes in `sources`, including `sources` + result = set(sources) + for node in model.get_ordered_ops(): + input_nodes = set(get_input_nodes(node)) + if input_nodes & result: + result.add(node) + return result + + +def get_read_value_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"] + + +def get_shape_of_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"] + + +def get_consumer_nodes(node): + consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()]) + return set(input.get_node() for input in consumer_inputs) + + +def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list): + # Search for nodes in the model graph that depend on nodes in `starts` list but independent of other model Parameter's/ReadValue's + other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources) + other_nodes = find_dependent_nodes(model, other_inputs) + source_dependent_nodes = find_dependent_nodes(model, sources) + # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph + nodes = source_dependent_nodes - other_nodes + edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes] + return edge_nodes + + +def insert_state_for_nodes(model: ov.Model, nodes): + # For each output in a given list `nodes` of ov.Node's, insert ReadValue-Assign pair and use the node output as initialization sub-expression + outputs = sum((node.outputs() for node in nodes), []) + for output in outputs: + consumers = output.get_target_inputs() + # FIXME: get_any_name is not reliable as tensor may not have any names + variable_id = output.get_any_name() + read_value = ov.runtime.opset13.read_value(output, variable_id) + for consumer in consumers: + consumer.replace_source_output(read_value.output(0)) + assign = ov.runtime.opset13.assign(read_value, variable_id) + model.add_sinks([assign]) def ensure_model_type_support_stateful(model_type: str): @@ -212,6 +279,12 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name openvino model """ + if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"): + return patch_stateful_encoder_decoder(config, ov_model) + return patch_stateful_decoder(config, ov_model) + + +def patch_stateful_decoder(config: PretrainedConfig, ov_model: 
ov.Model):
     key_value_input_names = [
         key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name
     ]
@@ -235,3 +308,19 @@
     make_stateful(
         ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
     )
+
+
+def patch_stateful_encoder_decoder(config, ov_model):
+    encoder_key_value_input_names = [
+        key.get_any_name()
+        for key in ov_model.inputs
+        if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
+    ]
+    remove_parameters_by_names(ov_model, encoder_key_value_input_names)
+    patch_stateful_decoder(config, ov_model)
+    insert_state_for_nodes(
+        ov_model,
+        find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
+    )
diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py
index 0ce15641fe..d16c533441 100644
--- a/optimum/intel/openvino/modeling_base_seq2seq.py
+++ b/optimum/intel/openvino/modeling_base_seq2seq.py
@@ -26,6 +26,7 @@
 from ...exporters.openvino import main_export
 from ..utils.import_utils import is_transformers_version
+from ...exporters.openvino.stateful import model_has_state
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel
 from .utils import (
@@ -64,7 +65,7 @@ def __init__(
         **kwargs,
     ):
         self.config = config
-        self.use_cache = decoder_with_past is not None
+        self.use_cache = decoder_with_past is not None or model_has_state(decoder)
         self.model_save_dir = model_save_dir
         self._compile_only = kwargs.get("compile_only", False)
         self._device = device.upper()
@@ -75,7 +76,8 @@ def __init__(
         if self.is_dynamic and not self._compile_only:
             encoder = self._reshape(encoder, -1, -1, is_decoder=False)
             decoder = self._reshape(decoder, -1, -1)
-            decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None
+            if decoder_with_past is not None:
+                decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None
         self.encoder_model = encoder
         self.decoder_model = decoder
         self.decoder_with_past_model = decoder_with_past
@@ -115,7 +117,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         src_files = [self.encoder_model, self.decoder_model]
         dst_file_names = [OV_ENCODER_NAME, OV_DECODER_NAME]
-        if self.use_cache:
+        if self.decoder_with_past_model is not None:
             src_files.append(self.decoder_with_past_model)
             dst_file_names.append(OV_DECODER_WITH_PAST_NAME)
@@ -204,7 +206,7 @@ def _from_pretrained(
             if not compile_only:
                 encoder = cls.load_model(os.path.join(model_id, encoder_file_name), quantization_config)
                 decoder = cls.load_model(os.path.join(model_id, decoder_file_name), quantization_config)
-                if use_cache:
+                if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)):
                     decoder_with_past = cls.load_model(
                         os.path.join(model_id, decoder_with_past_file_name), quantization_config
                     )
@@ -221,7 +223,7 @@ def _from_pretrained(
                     kwargs.get("ov_config"),
                     model_save_dir,
                 )
-                if use_cache:
+                if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)):
                     decoder_with_past = cls._compile_model(
                         os.path.join(model_id, decoder_with_past_file_name),
                         kwargs.get("device", "CPU"),
@@ -232,8 +234,6 @@ def _from_pretrained(
         # Load model from hub
         else:
             model_file_names = {"encoder": encoder_file_name, 
"decoder": decoder_file_name} - if use_cache: - model_file_names["decoder_with_past"] = decoder_with_past_file_name # If not ONNX then OpenVINO IR : adds binary files if not from_onnx: @@ -257,7 +257,21 @@ def _from_pretrained( if not compile_only: encoder = cls.load_model(file_names["encoder"], quantization_config) decoder = cls.load_model(file_names["decoder"], quantization_config) - if use_cache: + if use_cache and not model_has_state(decoder): + model_file_names["decoder_with_past"] = decoder_with_past_file_name + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in ["decoder_with_past", "decoder_with_past_bin"]: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config) else: encoder = cls._compile_model( @@ -266,7 +280,21 @@ def _from_pretrained( decoder = cls._compile_model( file_names["decoder"], kwargs.get("device", "CPU"), kwargs.get("ov_config"), model_save_dir ) - if use_cache: + if use_cache and not model_has_state(decoder): + model_file_names["decoder_with_past"] = decoder_with_past_file_name + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in ["decoder_with_past", "decoder_with_past_bin"]: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls._compile_model( file_names["decoder_with_past"], kwargs.get("device", "CPU"), @@ -291,7 +319,6 @@ def _from_pretrained( logger.info( "Generation config file not found, using a generation config created from the model config." 
) - return cls( encoder=encoder, decoder=decoder, @@ -366,6 +393,8 @@ def _from_transformers( else: ov_config = OVConfig(dtype="fp32") + stateful = kwargs.get("stateful", True) + main_export( model_name_or_path=model_id, output=save_dir_path, @@ -378,6 +407,7 @@ def _from_transformers( force_download=force_download, trust_remote_code=trust_remote_code, ov_config=ov_config, + stateful=stateful, ) return cls._from_pretrained( @@ -400,7 +430,8 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng elif inputs.get_any_name().startswith("cache_position"): shapes[inputs][0] = sequence_length elif is_decoder and not inputs.get_any_name().startswith("encoder"): - shapes[inputs][1] = -1 + if not inputs.get_any_name().startswith("beam_idx"): + shapes[inputs][1] = -1 else: shapes[inputs][1] = sequence_length model.reshape(shapes) @@ -424,7 +455,7 @@ def reshape(self, batch_size: int, sequence_length: int): self.is_dynamic = True if batch_size == -1 and sequence_length == -1 else False self.encoder_model = self._reshape(self.encoder_model, batch_size, sequence_length, is_decoder=False) self.decoder_model = self._reshape(self.decoder_model, batch_size, sequence_length) - if self.use_cache: + if self.decoder_with_past_model is not None: self.decoder_with_past_model = self._reshape(self.decoder_with_past_model, batch_size, sequence_length) def half(self): @@ -439,7 +470,7 @@ def half(self): apply_moc_transformations(self.decoder_model, cf=False) compress_model_transformation(self.encoder_model) compress_model_transformation(self.decoder_model) - if self.use_cache: + if self.decoder_with_past_model is not None: apply_moc_transformations(self.decoder_with_past_model, cf=False) compress_model_transformation(self.decoder_with_past_model) return self diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 0ccf78a361..c4d9621b1b 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -34,10 +34,11 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput - -from ..utils import is_transformers_version +from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +from ...exporters.openvino.stateful import model_has_state from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties +from ..utils import is_transformers_version if is_transformers_version(">=", "4.43.0"): @@ -132,9 +133,7 @@ >>> from optimum.intel import {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - >>> pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer) - >>> text = "He never went out without a book under his arm, and he often came back with two." 
+    >>> model = {model_class}.from_pretrained("{checkpoint}")
+    >>> pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
+    >>> text = "He never went out without a book under his arm, and he often came back with two."
     >>> outputs = pipe(text)
     ```
 """
@@ -329,7 +328,7 @@ def __init__(
         self.encoder = OVEncoder(self.encoder_model, parent_model=self)
         self.decoder = OVDecoder(self.decoder_model, parent_model=self)
-        if self.use_cache:
+        if self.use_cache and self.decoder_with_past_model is not None:
             self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self)
         if enable_compilation:
             self.compile()
@@ -345,6 +344,19 @@ def __init__(
     def dtype(self) -> Optional[torch.dtype]:
         return self.encoder.dtype or self.decoder.dtype

+    def to(self, device: str):
+        if isinstance(device, str):
+            self._device = device.upper()
+            self.encoder._device = self._device
+            self.decoder._device = self._device
+            if self.use_cache and self.decoder_with_past is not None:
+                self.decoder_with_past._device = self._device
+            self.clear_requests()
+        else:
+            logger.debug(f"device must be of type {str} but got {type(device)} instead")
+
+        return self
+
     @add_start_docstrings_to_model_forward(
         SEQ2SEQ_MODEL_DOCSTRING.format("batch_size, sequence_length")
         + TRANSLATION_EXAMPLE.format(
@@ -369,12 +381,14 @@ def forward(
             encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

         # Decode
-        if past_key_values is None or self.decoder_with_past is None:
+        if past_key_values is None or self.decoder.stateful:
             decoder_outputs = self.decoder(
-                input_ids=decoder_input_ids,
+                input_ids=decoder_input_ids[:, -1:] if past_key_values is not None else decoder_input_ids,
                 encoder_hidden_states=encoder_outputs.last_hidden_state,
                 encoder_attention_mask=attention_mask,
                 decoder_attention_mask=decoder_attention_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
             )
         else:
             decoder_outputs = self.decoder_with_past(
@@ -414,16 +428,8 @@ def prepare_inputs_for_generation(
     def get_encoder(self):
         return self.encoder

-    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
-    @staticmethod
-    def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        reordered_past = ()
-        for layer_past in past:
-            # Cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
+    def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
+        return self.decoder._reorder_cache(past, beam_idx)

     def reshape(self, batch_size: int, sequence_length: int):
         """
@@ -458,13 +464,13 @@ def clear_requests(self):
             )
         self.encoder.request = None
         self.decoder.request = None
-        if self.use_cache:
+        if self.use_cache and self.decoder_with_past is not None:
             self.decoder_with_past.request = None

     def compile(self):
         self.encoder._compile()
         self.decoder._compile()
-        if self.use_cache:
+        if self.use_cache and self.decoder_with_past is not None:
             self.decoder_with_past._compile()
@@ -575,15 +581,15 @@ def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM):
         self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
         self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
         self.key_value_output_names = [key for key in self.output_names if "key_values" in key or "present" in key]
+        self.stateful = model_has_state(self.model)
         is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs)
+        self.use_past = len(self.key_value_input_names) 
> 0 or self.stateful + self.next_beam_idx = None if len(self.key_value_input_names) > 0 and not is_legacy: - self.use_past = True self.num_pkv = 2 else: - self.use_past = False self.num_pkv = 4 - self.request = None if not self._compile_only else self.model.create_infer_request() @property @@ -622,7 +628,11 @@ def forward( # Model inputs inputs = {} - if past_key_values is not None: + if self.stateful and past_key_values is None: + self.request.reset_state() + self._past_len = 0 + + if past_key_values is not None and not self.stateful: # Flatten the past_key_values past_key_values = tuple( past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer @@ -644,17 +654,28 @@ def forward( if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None: inputs["decoder_attention_mask"] = decoder_attention_mask - if "cache_position" in self.input_names and cache_position is not None: + if "cache_position" in self.input_names: + if cache_position is None: + cache_position = torch.arange(self._past_len, self._past_len + input_ids.shape[1]) inputs["cache_position"] = cache_position + if "beam_idx" in self.input_names: + batch_size = input_ids.shape[0] + inputs["beam_idx"] = ( + self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32) + ) # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the - # self-attention layer and 2 to the cross-attention layer) - out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) + self._past_len += input_ids.shape[1] + + out_past_key_values = () + if not self.stateful: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the + # self-attention layer and 2 to the cross-attention layer) + out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) # Tuple of tuple of length `n_layers`, with each tuple of length equal to: # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention) @@ -664,11 +685,12 @@ def forward( out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv) ) else: - # grab the cross attention key/values from the inputs - out_past_key_values = tuple( - out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] - for i in range(0, len(out_past_key_values), self.num_pkv) - ) + if not self.stateful: + # grab the cross attention key/values from the inputs + out_past_key_values = tuple( + out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] + for i in range(0, len(out_past_key_values), self.num_pkv) + ) return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values) @@ -694,6 +716,26 @@ def _compile(self): logger.info(f"{self._device} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(compiled_model) + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. 
+ This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + if self.stateful: + self.next_beam_idx = np.array(beam_idx) + return past_key_values + else: + reordered_past = () + for layer_past in past_key_values: + # Cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + @add_start_docstrings( """ @@ -785,7 +827,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model @@ -868,7 +912,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model @@ -965,9 +1011,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneratio """ auto_model_class = WhisperForConditionalGeneration - - # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods - prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation generate = WhisperForConditionalGeneration.generate @classmethod @@ -995,3 +1038,67 @@ def __init__(self, stride): # a dummy model attribute that's used in the generate method to compute the input stride # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] model = DummyWhisperModel() + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + decoder_attention_mask=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time + # this function needs to be touched, let's try to sort out the commonalities between the two and remove the + # overwrite. 
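+        # NOTE (summary of the stateful flow added in this diff): when `self.decoder.stateful` is True
+        # the self-attention KV-cache lives inside the OpenVINO infer request, so `past_key_values`
+        # seen here is only a placeholder; the number of already-decoded tokens is taken from the
+        # `_past_len` counter kept by OVDecoder (see its forward above) rather than from cache tensors.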
+ + decoder_position_ids = None + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + + past_length = 0 + if past_key_values is not None: + if self.decoder.stateful: + past_length = getattr(self.decoder, "_past_len", 0) + else: + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + if decoder_position_ids is not None: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) + + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device + ) + elif use_cache: + cache_position = cache_position[-decoder_input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + decoder_input_ids = decoder_input_ids.contiguous() + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "use_cache": use_cache, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "cache_position": cache_position, + } diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 67511bb845..d3f2d5b4ff 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -261,7 +261,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str): if task.startswith("text2text-generation"): models = [model.encoder, model.decoder] - if task.endswith("with-past"): + if task.endswith("with-past") and not model.decoder.stateful: models.append(model.decoder_with_past) elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"): models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder] diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f7f677bf8c..c7a70f8139 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1525,7 +1525,8 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_model.encoder, OVEncoder) self.assertIsInstance(ov_model.decoder, OVDecoder) - self.assertIsInstance(ov_model.decoder_with_past, OVDecoder) + self.assertTrue(ov_model.decoder.stateful) + self.assertTrue(ov_model.decoder_with_past is None) self.assertIsInstance(ov_model.config, PretrainedConfig) transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 2869acf834..f74d949188 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -566,7 +566,7 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust self.assertEqual(model._openvino_config.dtype, "int8") if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder, model.decoder_with_past] + models = [model.encoder, model.decoder] elif model.export_feature == "text-to-image": models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) @@ -706,8 +706,8 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type, tru MODEL_NAMES[model_type], export=True, load_in_8bit=False, trust_remote_code=trust_remote_code ) if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder, model.decoder_with_past] - elif model.export_feature == "text-to-image": + models = [model.encoder, model.decoder] + elif model.export_feature.startswith("text-to-image"): models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) elif model_type == "open-clip": @@ -1126,11 +1126,12 @@ def _generate_random_audio_data(processor): @parameterized.expand(itertools.product(MODEL_ID, APPLY_CACHING)) def test_calibration_data_uniqueness(self, model_id, apply_caching): ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True) + self.assertTrue(ov_model.decoder_with_past is None) processor = AutoProcessor.from_pretrained(model_id) calibration_data = [] - 
ov_model.decoder_with_past.request = InferRequestWrapper( - ov_model.decoder_with_past.request, calibration_data, apply_caching=apply_caching + ov_model.decoder.request = InferRequestWrapper( + ov_model.decoder.request, calibration_data, apply_caching=apply_caching ) for _ in range(2): input_features = self._generate_random_audio_data(processor)
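
For reference, a minimal usage sketch of the stateful seq2seq path exercised by the tests above, assuming the export default shown in `_from_transformers` (`stateful=True`); the checkpoint and prompt are illustrative, not part of this patch:

```python
from transformers import AutoTokenizer

from optimum.intel import OVModelForSeq2SeqLM

model_id = "t5-small"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export through the OpenVINO exporter; with stateful=True (the default added in _from_transformers)
# the decoder KV-cache is kept in OpenVINO states, so no separate decoder_with_past model is produced.
model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, stateful=True)
assert model.decoder.stateful
assert getattr(model, "decoder_with_past", None) is None

inputs = tokenizer("translate English to French: Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

The same flow applies to `OVModelForSpeechSeq2Seq`, which the updated quantization and modeling tests run against a stateful Whisper decoder.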