From be65aaaf47811ee95c68a9c90af63b1e0dfff5a5 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 19 Jun 2024 20:25:05 +0400 Subject: [PATCH 01/11] run stateful whisper --- .../intel/openvino/modeling_base_seq2seq.py | 17 ++-- optimum/intel/openvino/modeling_seq2seq.py | 89 +++++++++++++------ 2 files changed, 72 insertions(+), 34 deletions(-) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 0ce15641fe..d7f487ce34 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -26,6 +26,7 @@ from ...exporters.openvino import main_export from ..utils.import_utils import is_transformers_version +from ...exporters.openvino.stateful import model_has_state from .configuration import OVConfig, OVWeightQuantizationConfig from .modeling_base import OVBaseModel from .utils import ( @@ -64,7 +65,7 @@ def __init__( **kwargs, ): self.config = config - self.use_cache = decoder_with_past is not None + self.use_cache = decoder_with_past is not None or model_has_state(decoder) self.model_save_dir = model_save_dir self._compile_only = kwargs.get("compile_only", False) self._device = device.upper() @@ -75,7 +76,8 @@ def __init__( if self.is_dynamic and not self._compile_only: encoder = self._reshape(encoder, -1, -1, is_decoder=False) decoder = self._reshape(decoder, -1, -1) - decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None + if decoder_with_past is not None: + decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None self.encoder_model = encoder self.decoder_model = decoder self.decoder_with_past_model = decoder_with_past @@ -204,7 +206,7 @@ def _from_pretrained( if not compile_only: encoder = cls.load_model(os.path.join(model_id, encoder_file_name), quantization_config) decoder = cls.load_model(os.path.join(model_id, decoder_file_name), quantization_config) - if use_cache: + if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)): decoder_with_past = cls.load_model( os.path.join(model_id, decoder_with_past_file_name), quantization_config ) @@ -221,7 +223,7 @@ def _from_pretrained( kwargs.get("ov_config"), model_save_dir, ) - if use_cache: + if use_cache and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)): decoder_with_past = cls._compile_model( os.path.join(model_id, decoder_with_past_file_name), kwargs.get("device", "CPU"), @@ -400,7 +402,8 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng elif inputs.get_any_name().startswith("cache_position"): shapes[inputs][0] = sequence_length elif is_decoder and not inputs.get_any_name().startswith("encoder"): - shapes[inputs][1] = -1 + if not inputs.get_any_name().startswith("beam_idx"): + shapes[inputs][1] = -1 else: shapes[inputs][1] = sequence_length model.reshape(shapes) @@ -424,7 +427,7 @@ def reshape(self, batch_size: int, sequence_length: int): self.is_dynamic = True if batch_size == -1 and sequence_length == -1 else False self.encoder_model = self._reshape(self.encoder_model, batch_size, sequence_length, is_decoder=False) self.decoder_model = self._reshape(self.decoder_model, batch_size, sequence_length) - if self.use_cache: + if self.decoder_with_past_model is not None: self.decoder_with_past_model = self._reshape(self.decoder_with_past_model, batch_size, sequence_length) def half(self): @@ -439,7 +442,7 @@ def half(self): apply_moc_transformations(self.decoder_model, cf=False) 
compress_model_transformation(self.encoder_model) compress_model_transformation(self.decoder_model) - if self.use_cache: + if self.decoder_with_past_model is not None: apply_moc_transformations(self.decoder_with_past_model, cf=False) compress_model_transformation(self.decoder_with_past_model) return self diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 0ccf78a361..14880ae1ec 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -34,6 +34,8 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +from ...exporters.openvino.stateful import model_has_state from ..utils import is_transformers_version from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM @@ -132,9 +134,7 @@ >>> from optimum.intel import {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}") - >>> pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer) - >>> text = "He never went out without a book under his arm, and he often came back with two." + >>> model = {model_class}.from_pretrained("{checkpoint}")Whisper >>> outputs = pipe(text) ``` """ @@ -329,7 +329,7 @@ def __init__( self.encoder = OVEncoder(self.encoder_model, parent_model=self) self.decoder = OVDecoder(self.decoder_model, parent_model=self) - if self.use_cache: + if self.use_cache and self.decoder_with_past is not None: self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self) if enable_compilation: self.compile() @@ -345,6 +345,19 @@ def __init__( def dtype(self) -> Optional[torch.dtype]: return self.encoder.dtype or self.decoder.dtype + def to(self, device: str): + if isinstance(device, str): + self._device = device.upper() + self.encoder._device = self._device + self.decoder._device = self._device + if self.use_cache and self.decoder_with_past_model is not None: + self.decoder_with_past._device = self._device + self.clear_requests() + else: + logger.debug(f"device must be of type {str} but got {type(device)} instead") + + return self + @add_start_docstrings_to_model_forward( SEQ2SEQ_MODEL_DOCSTRING.format("batch_size, sequence_length") + TRANSLATION_EXAMPLE.format( @@ -371,10 +384,11 @@ def forward( # Decode if past_key_values is None or self.decoder_with_past is None: decoder_outputs = self.decoder( - input_ids=decoder_input_ids, + input_ids=decoder_input_ids[:, -1:] if past_key_values is not None else decoder_input_ids, encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + past_key_values = past_key_values ) else: decoder_outputs = self.decoder_with_past( @@ -414,16 +428,8 @@ def prepare_inputs_for_generation( def get_encoder(self): return self.encoder - # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - @staticmethod - def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: - reordered_past = () - for layer_past in past: - # Cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + 
layer_past[2:], - ) - return reordered_past + def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: + self.decoder._reorder_cache(past, beam_idx) def reshape(self, batch_size: int, sequence_length: int): """ @@ -458,13 +464,13 @@ def clear_requests(self): ) self.encoder.request = None self.decoder.request = None - if self.use_cache: + if self.use_cache and self.decoder_with_past_model is not None: self.decoder_with_past.request = None def compile(self): self.encoder._compile() self.decoder._compile() - if self.use_cache: + if self.use_cache and self.decoder_with_past_model is not None: self.decoder_with_past._compile() @@ -477,7 +483,7 @@ class OVEncoder: The OpenVINO inference request associated to the encoder. """ - def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM): + def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM, merged=False): self.model = model self.parent_model = parent_model self._comple_only = parent_model._compile_only @@ -575,15 +581,15 @@ def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2Se self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs} self.key_value_output_names = [key for key in self.output_names if "key_values" in key or "present" in key] + self.stateful = model_has_state(self.model) is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs) + self.use_past = len(self.key_value_input_names) > 0 or self.stateful + self.next_beam_idx = None if len(self.key_value_input_names) > 0 and not is_legacy: - self.use_past = True self.num_pkv = 2 else: - self.use_past = False self.num_pkv = 4 - self.request = None if not self._compile_only else self.model.create_infer_request() @property @@ -622,7 +628,10 @@ def forward( # Model inputs inputs = {} - if past_key_values is not None: + if self.stateful and past_key_values is None: + self.request.reset_state() + + if past_key_values is not None and not self.stateful: # Flatten the past_key_values past_key_values = tuple( past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer @@ -647,6 +656,11 @@ def forward( if "cache_position" in self.input_names and cache_position is not None: inputs["cache_position"] = cache_position + if "beam_idx" in self.input_names: + batch_size = input_ids.shape[0] + inputs["beam_idx"] = ( + self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) + ) # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() @@ -664,11 +678,12 @@ def forward( out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv) ) else: - # grab the cross attention key/values from the inputs - out_past_key_values = tuple( - out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] - for i in range(0, len(out_past_key_values), self.num_pkv) - ) + if not self.stateful: + # grab the cross attention key/values from the inputs + out_past_key_values = tuple( + out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] + for i in range(0, len(out_past_key_values), self.num_pkv) + ) return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values) @@ -694,6 +709,26 @@ def _compile(self): logger.info(f"{self._device} 
SUPPORTED_PROPERTIES:") _print_compiled_model_properties(compiled_model) + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + if self.stateful: + self.next_beam_idx = np.array(beam_idx) + return past_key_values + else: + reordered_past = () + for layer_past in past_key_values: + # Cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + @add_start_docstrings( """ From 81e01ee2d0df1cc98daa5681e815141dcdc05ccd Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 24 Jun 2024 22:39:37 +0400 Subject: [PATCH 02/11] fix conflict --- notebooks/ipex/text_generation.ipynb | 8 +- .../openvino/optimum_openvino_inference.ipynb | 4 +- .../openvino/quantized_generation_demo.ipynb | 86 +++++++++---------- ...stable_diffusion_hybrid_quantization.ipynb | 25 +++--- optimum/exporters/openvino/convert.py | 70 +++++++++++++-- optimum/exporters/openvino/model_configs.py | 72 +++++++++++++++- optimum/exporters/openvino/stateful.py | 84 +++++++++++++++++- 7 files changed, 278 insertions(+), 71 deletions(-) diff --git a/notebooks/ipex/text_generation.ipynb b/notebooks/ipex/text_generation.ipynb index d1a62d9201..df46355531 100644 --- a/notebooks/ipex/text_generation.ipynb +++ b/notebooks/ipex/text_generation.ipynb @@ -62,9 +62,13 @@ "source": [ "model = IPEXModelForCausalLM.from_pretrained(\"gpt2\", torch_dtype=torch.bfloat16, export=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", - "input_sentence = [\"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?\"]\n", + "input_sentence = [\n", + " \"Answer the following yes/no question by reasoning step-by-step please. 
Can you write a whole Haiku in a single tweet?\"\n", + "]\n", "model_inputs = tokenizer(input_sentence, return_tensors=\"pt\")\n", - "generation_kwargs = dict(max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True)\n", + "generation_kwargs = dict(\n", + " max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True\n", + ")\n", "\n", "generated_ids = model.generate(**model_inputs, **generation_kwargs)\n", "output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n", diff --git a/notebooks/openvino/optimum_openvino_inference.ipynb b/notebooks/openvino/optimum_openvino_inference.ipynb index 76c77aec55..7ef14e0635 100644 --- a/notebooks/openvino/optimum_openvino_inference.ipynb +++ b/notebooks/openvino/optimum_openvino_inference.ipynb @@ -466,7 +466,9 @@ "source": [ "# Set the device directly with `.from_pretrained()`\n", "if \"GPU\" in Core().available_devices:\n", - " model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp16\", device=\"GPU\")" + " model = OVModelForQuestionAnswering.from_pretrained(\n", + " \"distilbert-base-uncased-distilled-squad-ov-fp16\", device=\"GPU\"\n", + " )" ] }, { diff --git a/notebooks/openvino/quantized_generation_demo.ipynb b/notebooks/openvino/quantized_generation_demo.ipynb index 5673243cb2..cc5c1ec2b3 100644 --- a/notebooks/openvino/quantized_generation_demo.ipynb +++ b/notebooks/openvino/quantized_generation_demo.ipynb @@ -121,7 +121,7 @@ " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n", " },\n", " \"compile\": False,\n", - " \"quantization_config\": quantization_config\n", + " \"quantization_config\": quantization_config,\n", "}\n", "\n", "# Check whether the model was already exported\n", @@ -143,8 +143,8 @@ "\n", "# TODO Optional: export to huggingface/hub\n", "\n", - "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024 ** 3\n", - "print(f'Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB')" + "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024**3\n", + "print(f\"Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB\")" ] }, { @@ -212,7 +212,7 @@ "from transformers import TextStreamer\n", "\n", "# Tokenize the sample\n", - "inputs = tokenizer([sample], return_tensors='pt')\n", + "inputs = tokenizer([sample], return_tensors=\"pt\")\n", "\n", "# Call generate on the inputs\n", "out = model.generate(\n", @@ -294,7 +294,7 @@ "\n", "\n", "# Tokenize the sample\n", - "inputs = tokenizer([sample], return_tensors='pt') \n", + "inputs = tokenizer([sample], return_tensors=\"pt\")\n", "\n", "out = stateless_model.generate(\n", " **inputs,\n", @@ -302,7 +302,7 @@ " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", " pad_token_id=tokenizer.eos_token_id,\n", " prompt_lookup_num_tokens=3,\n", - ") " + ")" ] }, { @@ -358,7 +358,7 @@ " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n", " },\n", " \"compile\": False,\n", - " \"quantization_config\": quantization_config\n", + " \"quantization_config\": quantization_config,\n", "}\n", "\n", "# Check whether the model was already exported\n", @@ -458,15 +458,15 @@ " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n", " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n", " self.model_forward = 
self.model.forward\n", - " \n", + "\n", " @wraps(self.model_forward)\n", " def forward_wrapper(**kwargs):\n", " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n", " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n", " return self.model_forward(**kwargs)\n", - " \n", + "\n", " self.model.forward = forward_wrapper\n", - " \n", + "\n", " # wrap generate method\n", " self.model_generate = self.model.generate\n", "\n", @@ -479,10 +479,11 @@ " out = self.model_generate(*args, **kwargs)\n", " self.seq_lens[-1].append(out.shape[-1])\n", " return out\n", + "\n", " self.model.generate = generate_wrapper\n", " return self\n", "\n", - " def __exit__(self, type, value, traceback):\n", + " def __exit__(self, type, value, traceback):\n", " self.model.forward = self.model_forward\n", " self.model.generate = self.model_generate\n", " self.model_forward = None\n", @@ -494,7 +495,7 @@ " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n", " # Add window size for output to ease calculation later\n", " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", - " ws.append(0) \n", + " ws.append(0)\n", "\n", " def acceptance_rate(self, return_mean=True, normalize=False):\n", " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n", @@ -503,9 +504,8 @@ " sl = np.array(sl, dtype=np.float64)\n", " ws = np.array(ws, dtype=np.float64)\n", " out_lens = sl - ws\n", - " accepted = (out_lens[1:] - out_lens[:-1] - 1)\n", - " ar_per_win.append(np.divide(accepted, ws[:-1],\n", - " out=np.zeros_like(accepted),where=ws[:-1] != 0))\n", + " accepted = out_lens[1:] - out_lens[:-1] - 1\n", + " ar_per_win.append(np.divide(accepted, ws[:-1], out=np.zeros_like(accepted), where=ws[:-1] != 0))\n", " ar_per_win = np.hstack(ar_per_win)\n", " # Normalized AR doesn't take into account windows with size 0\n", " if normalize:\n", @@ -544,7 +544,7 @@ "samples_number = 30\n", "with AcceptanceRateRecorder(stateless_model) as ar_recorder:\n", " for text in tqdm(dataset[:samples_number]):\n", - " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors='pt')\n", + " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors=\"pt\")\n", " stateless_model.generate(\n", " **tokenized_prompt,\n", " max_new_tokens=128,\n", @@ -623,7 +623,6 @@ " return False\n", "\n", "\n", - "\n", "# Set the chat template to the tokenizer. 
The chat template implements the simple template of\n", "# User: content\n", "# Assistant: content\n", @@ -651,11 +650,7 @@ " if model_msg:\n", " messages.append({\"role\": \"Assistant\", \"content\": model_msg})\n", " input_token = tokenizer.apply_chat_template(\n", - " messages,\n", - " add_generation_prompt=True,\n", - " tokenize=True,\n", - " return_tensors=\"pt\",\n", - " return_dict=True\n", + " messages, add_generation_prompt=True, tokenize=True, return_tensors=\"pt\", return_dict=True\n", " )\n", " return input_token\n", "\n", @@ -679,18 +674,18 @@ " # Construct the input message string for the model by concatenating the current system message and conversation history\n", " # Tokenize the messages string\n", " inputs = prepare_history_for_model(history)\n", - " input_length = inputs['input_ids'].shape[1]\n", + " input_length = inputs[\"input_ids\"].shape[1]\n", " # truncate input in case it is too long.\n", " # TODO improve this\n", " if input_length > 2000:\n", " history = [history[-1]]\n", " inputs = prepare_history_for_model(history)\n", - " input_length = inputs['input_ids'].shape[1]\n", + " input_length = inputs[\"input_ids\"].shape[1]\n", "\n", " prompt_char = \"▌\"\n", " history[-1][1] = prompt_char\n", " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", - " \n", + "\n", " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", "\n", " # Create a stopping criteria to prevent the model from playing the role of the user aswell.\n", @@ -706,11 +701,14 @@ " eos_token_id=[tokenizer.eos_token_id],\n", " pad_token_id=tokenizer.eos_token_id,\n", " )\n", - " generate_kwargs = dict(\n", - " streamer=streamer,\n", - " generation_config=generation_config,\n", - " stopping_criteria=stopping_criteria,\n", - " ) | inputs\n", + " generate_kwargs = (\n", + " dict(\n", + " streamer=streamer,\n", + " generation_config=generation_config,\n", + " stopping_criteria=stopping_criteria,\n", + " )\n", + " | inputs\n", + " )\n", "\n", " if assisted:\n", " target_generate = stateless_model.generate\n", @@ -737,7 +735,7 @@ " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", " history[-1][1] = partial_text\n", " generation_time = time.perf_counter() - start\n", - " yield history, f'Generation time: {generation_time:.2f} sec', *([gr.update(interactive=True)] * 4)" + " yield history, f\"Generation time: {generation_time:.2f} sec\", *([gr.update(interactive=True)] * 4)" ] }, { @@ -781,7 +779,9 @@ " [\"Can you explain to me briefly what is Python programming language?\"],\n", " [\"Explain the plot of Cinderella in a sentence.\"],\n", " [\"Write a Python function to perform binary search over a sorted list. Use markdown to write code\"],\n", - " [\"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?\"],\n", + " [\n", + " \"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. 
How long will it take for the ball to reach the ground?\"\n", + " ],\n", "]\n", "\n", "\n", @@ -797,7 +797,7 @@ " \"\"\"\n", " # Append current user message to history with a blank assistant message which will be generated by the model\n", " history.append([message, None])\n", - " return ('', history)\n", + " return (\"\", history)\n", "\n", "\n", "def prepare_for_regenerate(history):\n", @@ -808,7 +808,7 @@ " history: conversation history\n", " Returns:\n", " updated history\n", - " \"\"\" \n", + " \"\"\"\n", " history[-1][1] = None\n", " return history\n", "\n", @@ -821,7 +821,7 @@ " msg = gr.Textbox(placeholder=\"Enter message here...\", show_label=False, autofocus=True, scale=75)\n", " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=15)\n", " with gr.Row():\n", - " submit = gr.Button(\"Submit\", variant='primary')\n", + " submit = gr.Button(\"Submit\", variant=\"primary\")\n", " regenerate = gr.Button(\"Regenerate\")\n", " clear = gr.Button(\"Clear\")\n", " with gr.Accordion(\"Advanced Options:\", open=False):\n", @@ -860,9 +860,7 @@ " step=0.1,\n", " interactive=True,\n", " )\n", - " gr.Examples(\n", - " EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\"\n", - " )\n", + " gr.Examples(EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\")\n", "\n", " # Sets generate function to be triggered when the user submit a new message\n", " gr.on(\n", @@ -876,20 +874,14 @@ " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", " concurrency_limit=1,\n", - " queue=True\n", - " )\n", - " regenerate.click(\n", - " fn=prepare_for_regenerate,\n", - " inputs=chatbot,\n", - " outputs=chatbot,\n", " queue=True,\n", - " concurrency_limit=1\n", - " ).then(\n", + " )\n", + " regenerate.click(fn=prepare_for_regenerate, inputs=chatbot, outputs=chatbot, queue=True, concurrency_limit=1).then(\n", " fn=generate,\n", " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", " concurrency_limit=1,\n", - " queue=True\n", + " queue=True,\n", " )\n", " clear.click(fn=lambda: (None, \"Status: Idle\"), inputs=None, outputs=[chatbot, status], queue=False)" ] diff --git a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb index 8ef2e8ad6c..d89457bd78 100644 --- a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb +++ b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb @@ -167,6 +167,7 @@ "def preprocess_fn(example):\n", " return {\"prompt\": example[\"caption\"]}\n", "\n", + "\n", "NUM_SAMPLES = 200\n", "dataset = dataset.take(NUM_SAMPLES)\n", "calibration_dataset = dataset.map(lambda x: preprocess_fn(x), remove_columns=dataset.column_names)" @@ -1066,12 +1067,14 @@ ], "source": [ "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)\n", - "quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)\n", + "quantization_config = OVWeightQuantizationConfig(\n", + " bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID\n", + ")\n", "quantizer = OVQuantizer(int8_pipe)\n", "quantizer.quantize(\n", " ov_config=OVConfig(quantization_config=quantization_config),\n", " calibration_dataset=calibration_dataset,\n", - " 
save_directory=int8_model_path\n", + " save_directory=int8_model_path,\n", ")" ] }, @@ -1202,8 +1205,10 @@ " im_w, im_h = fp32_img.size\n", " is_horizontal = im_h <= im_w\n", " figsize = (20, 30) if is_horizontal else (30, 20)\n", - " fig, axs = plt.subplots(1 if is_horizontal else 2, 2 if is_horizontal else 1, figsize=figsize, sharex='all', sharey='all')\n", - " fig.patch.set_facecolor('white')\n", + " fig, axs = plt.subplots(\n", + " 1 if is_horizontal else 2, 2 if is_horizontal else 1, figsize=figsize, sharex=\"all\", sharey=\"all\"\n", + " )\n", + " fig.patch.set_facecolor(\"white\")\n", " list_axes = list(axs.flat)\n", " for a in list_axes:\n", " a.set_xticklabels([])\n", @@ -1217,7 +1222,7 @@ " img2_title = \"INT8 result\"\n", " list_axes[0].set_title(img1_title, fontsize=20)\n", " list_axes[1].set_title(img2_title, fontsize=20)\n", - " fig.subplots_adjust(wspace=0.0 if is_horizontal else 0.01 , hspace=0.01 if is_horizontal else 0.0)\n", + " fig.subplots_adjust(wspace=0.0 if is_horizontal else 0.01, hspace=0.01 if is_horizontal else 0.0)\n", " fig.tight_layout()" ] }, @@ -1230,13 +1235,10 @@ "source": [ "prompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\n", "\n", + "\n", "def generate_image(pipeline, prompt):\n", " transformers.set_seed(1)\n", - " return pipeline(\n", - " prompt=prompt,\n", - " guidance_scale=8.0,\n", - " output_type=\"pil\"\n", - " ).images[0]" + " return pipeline(prompt=prompt, guidance_scale=8.0, output_type=\"pil\").images[0]" ] }, { @@ -1329,7 +1331,7 @@ "def get_model_size(model_folder, framework):\n", " \"\"\"\n", " Return OpenVINO or PyTorch model size in Mb.\n", - " \n", + "\n", " Arguments:\n", " model_folder:\n", " Directory containing a model.\n", @@ -1531,6 +1533,7 @@ "def get_val_dataset(num_items=3):\n", " return [item[\"caption\"] for item in dataset.take(num_items)]\n", "\n", + "\n", "def benchmark(pipeline, dataset):\n", " \"\"\"\n", " Benchmark PyTorch or OpenVINO model. This function does inference on `num_items`\n", diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index fdcfbecf53..4d48ff8455 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -36,9 +36,7 @@ from optimum.exporters.utils import ( _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, ) -from optimum.exporters.utils import ( - get_diffusion_models_for_export, -) +from optimum.exporters.utils import ENCODER_NAME, DECODER_NAME, DECODER_WITH_PAST_NAME, _get_submodels_for_export_encoder_decoder from optimum.intel.utils.import_utils import ( _diffusers_version, _nncf_version, @@ -534,6 +532,8 @@ def export_models( f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export." 
) + if not isinstance(stateful, (list, tuple)): + stateful = [stateful] * len(models_and_export_configs) for i, model_name in enumerate(models_and_export_configs.keys()): submodel, sub_export_config = models_and_export_configs[model_name] output_name = output_names[i] if output_names is not None else Path(model_name + ".xml") @@ -549,9 +549,9 @@ def export_models( input_shapes=input_shapes, model_kwargs=model_kwargs, ov_config=ov_config, - stateful=stateful[i] if isinstance(stateful, (list, tuple)) else stateful, patch_16bit_model=patch_16bit_model, library_name=library_name, + stateful=stateful[i], ) ) @@ -615,7 +615,7 @@ def export_from_model( logger.info(f"Automatic task detection to: {task}.") stateful = stateful and ( - ensure_export_task_support_stateful(task) + ensure_export_task_support_stateful(task, getattr(getattr(model, "config", {}), "is_encoder_decoder", False)) or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", "")) ) # TODO: support onnx_config.py in the model repo @@ -646,6 +646,8 @@ def export_from_model( input_shapes[input_name] = ( kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] ) + + logging.disable(logging.INFO) if library_name == "open_clip": custom_architecture = True @@ -653,7 +655,17 @@ def export_from_model( model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels ) - if library_name == "diffusers": + elif (model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS)) and stateful and not custom_architecture: + export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( + model=model, + task=task, + preprocessors=preprocessors, + library_name=library_name, + _variant="default" + ) + stateful = [False, True] + + elif library_name == "diffusers": export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino") stateful_submodels = False else: @@ -895,6 +907,7 @@ def _add_version_info_to_model(model: Model, library_name: Optional[str] = None) return model + def _get_multi_modal_submodels_and_export_configs( model: Union["PreTrainedModel", "TFPreTrainedModel"], task: str, @@ -933,7 +946,6 @@ def _get_multi_modal_submodels_and_export_configs( stateful_parts.append(stateful if getattr(model_part_config, "use_past", False) else False) return main_config, models_for_export, stateful_parts - def _get_submodels_and_export_configs( model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"], task: str, @@ -1193,3 +1205,47 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype): models_for_export["text_encoder_2"] = (text_encoder_2, export_config) return models_for_export + + +def _get_encoder_decoder_stateful_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + task: str, + _variant: str, + library_name: str, + int_dtype: str = "int64", + float_dtype: str = "fp32", + preprocessors: Optional[List[Any]] = None, +): + logger.info("HERE") + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="openvino", task=task, library_name=library_name + ) + export_config = export_config_constructor( + model.config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=False, + ) + + export_config.variant = _variant + all_variants = "\n".join( + [f" - {name}: {description}" for name, description in export_config.VARIANTS.items()] + ) + 
logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}") + + models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True) + + encoder_export_config = export_config.with_behavior("encoder") + models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config) + + decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True, stateful=True) + decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME) + models_for_export[DECODER_NAME] = ( + decoder_with_past_model, + decoder_export_config_with_past, + ) + logger.info(models_for_export.keys()) + + + return None, models_for_export diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index b8310882ba..a3140a0da6 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -18,6 +18,10 @@ from packaging import version from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel +from optimum.exporters.onnx.base import ConfigBehavior +import logging +logger = logging.getLogger(__name__) + from transformers.utils import is_tf_available from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig @@ -38,10 +42,13 @@ PhiOnnxConfig, UNetOnnxConfig, VisionOnnxConfig, + VaeDecoderOnnxConfig, + VaeEncoderOnnxConfig, + WhisperOnnxConfig ) from optimum.exporters.onnx.model_patcher import ModelPatcher from optimum.exporters.tasks import TasksManager -from optimum.utils import DEFAULT_DUMMY_SHAPES +from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator from optimum.utils.input_generators import ( DTYPE_MAPPER, DummyInputGenerator, @@ -2204,3 +2211,66 @@ def patch_model_for_export( if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS: return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs) return super().patch_model_for_export(model, model_kwargs) + + +@register_in_tasks_manager("whisper", *["feature-extraction", "feature-extraction-with-past", "audio-classification", "automatic-speech-recognition", "automatic-speech-recognition-with-past",], library_name="transformers") +class WhisperOpenVINOConfig(WhisperOnnxConfig): + def __init__(self, config: PretrainedConfig, task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, legacy: bool = False, stateful: bool = False): + self.stateful = stateful + logger.warn(f"config stateful: {self.stateful}") + super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy) + + + def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: + """ + Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`. + Each dummy input generator is independent, so this method instantiates the first generator, and + forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch + size. Override this method for custom behavior. 
+ """ + logger.warn(f"config stateful: {self.stateful}") + if self.stateful: + if "encoder_sequence_length" not in kwargs: + sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"]) + kwargs["encoder_sequence_length"] = sequence_len + 2 + logger.warn(kwargs) + return super()._create_dummy_input_generator_classes(**kwargs) + + def with_behavior( + self, + behavior: Union[str, ConfigBehavior], + use_past: bool = False, + use_past_in_inputs: bool = False, + stateful: bool = False + ) -> "OnnxSeq2SeqConfigWithPast": + """ + Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + use_past (`bool`, defaults to `False`): + Whether or not the ONNX config to instantiate is for a model using KV cache. + use_past_in_inputs (`bool`, defaults to `False`): + Whether the KV cache is to be passed as an input to the ONNX. + + Returns: + `OnnxSeq2SeqConfigWithPast` + """ + if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): + behavior = ConfigBehavior(behavior) + + onnx_config = self.__class__( + self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=self._preprocessors, + legacy=self.legacy, + stateful=stateful + ) + onnx_config.variant = self.variant + return onnx_config diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 4b4374ab51..476fc1b70d 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -191,9 +191,76 @@ def ensure_stateful_is_available(warn=True): return True -def ensure_export_task_support_stateful(task: str): +def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = False): task = TasksManager.map_from_synonym(task) - return task in ["text-generation-with-past"] + if not is_encoder_decoder: + return task == "text-generation-with-past" + + _ENCODER_DECODER_TASKS_WITH_PAST = ( + "automatic-speech-recognition", + "document-question-answering", + "image-to-text", + "text2text-generation", + "visual-question-answering", + ) + + is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + log.warn(f"is_stateful {is_stateful}") + return is_stateful + + +def remove_parameters_by_names(model: ov.Model, names: list): + parameters = [model.input(name).get_node() for name in names] + for p in parameters: + model.remove_parameter(p) + + +def get_input_nodes(node): + return [input.get_node() for input in node.input_values()] + + +def find_dependent_nodes(model: ov.Model, sources: list): + # Finds all nodes in `model` that are directly or indirectly dependent on at least one node from the list of nodes in `sources`, including `sources` + result = set(sources) + for node in model.get_ordered_ops(): + input_nodes = set(get_input_nodes(node)) + if input_nodes & result: + result.add(node) + return result + + +def get_read_value_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"] + + +def get_consumer_nodes(node): + consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()]) + return set(input.get_node() for input in consumer_inputs) + + +def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list): + # Search for nodes in the model graph 
that depend on nodes in `starts` list but independent of other model Parameter's/ReadValue's + other_inputs = set(model.get_parameters() + get_read_value_ops(model)) - set(sources) + other_nodes = find_dependent_nodes(model, other_inputs) + source_dependent_nodes = find_dependent_nodes(model, sources) + # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph + nodes = source_dependent_nodes - other_nodes + edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes] + return edge_nodes + + +def insert_state_for_nodes(model: ov.Model, nodes): + # For each output in a given list `nodes` of ov.Node's, insert ReadValue-Assign pair and use the node output as initialization sub-expression + outputs = sum((node.outputs() for node in nodes), []) + for output in outputs: + consumers = output.get_target_inputs() + # FIXME: get_any_name is not reliable as tensor may not have any names + variable_id = output.get_any_name() + read_value = ov.runtime.opset13.read_value(output, variable_id) + for consumer in consumers: + consumer.replace_source_output(read_value.output(0)) + assign = ov.runtime.opset13.assign(read_value, variable_id) + model.add_sinks([assign]) def ensure_model_type_support_stateful(model_type: str): @@ -212,6 +279,12 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name openvino model """ + if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"): + return patch_stateful_encoder_decoder(config, ov_model) + return patch_stateful_decoder(config, ov_model) + + +def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model): key_value_input_names = [ key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name ] @@ -235,3 +308,10 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name make_stateful( ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None ) + + +def patch_stateful_encoder_decoder(config, ov_model): + encoder_key_value_input_names = [key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())] + remove_parameters_by_names(ov_model, encoder_key_value_input_names) + patch_stateful_decoder(config, ov_model) + insert_state_for_nodes(ov_model, find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()])) \ No newline at end of file From de02ae734d99639b39da8677d9648658df0a1ce6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 24 Jun 2024 22:48:54 +0400 Subject: [PATCH 03/11] export part poc --- optimum/exporters/openvino/model_configs.py | 5 ----- optimum/exporters/openvino/stateful.py | 1 - optimum/intel/openvino/modeling_base_seq2seq.py | 2 +- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index a3140a0da6..e8865888ba 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -19,8 +19,6 @@ from packaging import version from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel from optimum.exporters.onnx.base import ConfigBehavior -import logging -logger = logging.getLogger(__name__) from transformers.utils import is_tf_available @@ -2217,7 +2215,6 @@ def patch_model_for_export( class 
WhisperOpenVINOConfig(WhisperOnnxConfig): def __init__(self, config: PretrainedConfig, task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, legacy: bool = False, stateful: bool = False): self.stateful = stateful - logger.warn(f"config stateful: {self.stateful}") super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy) @@ -2228,12 +2225,10 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch size. Override this method for custom behavior. """ - logger.warn(f"config stateful: {self.stateful}") if self.stateful: if "encoder_sequence_length" not in kwargs: sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"]) kwargs["encoder_sequence_length"] = sequence_len + 2 - logger.warn(kwargs) return super()._create_dummy_input_generator_classes(**kwargs) def with_behavior( diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 476fc1b70d..4aa0a35ab5 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -205,7 +205,6 @@ def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = Fal ) is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST - log.warn(f"is_stateful {is_stateful}") return is_stateful diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index d7f487ce34..08b44e258e 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -117,7 +117,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): """ src_files = [self.encoder_model, self.decoder_model] dst_file_names = [OV_ENCODER_NAME, OV_DECODER_NAME] - if self.use_cache: + if self.decoder_with_past_model is not None: src_files.append(self.decoder_with_past_model) dst_file_names.append(OV_DECODER_WITH_PAST_NAME) From 3e775cd5fbf1d0289651986e0c2b971ffc40030d Mon Sep 17 00:00:00 2001 From: eaidova Date: Mon, 24 Jun 2024 22:59:08 +0400 Subject: [PATCH 04/11] use SDPA for whisper --- optimum/exporters/openvino/__main__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index dba4628d79..cb2019dd36 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -264,9 +264,6 @@ def main_export( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." 
) - if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: - loading_kwargs["attn_implementation"] = "eager" - # some models force flash_attn attention by default that does not support load model on cpu if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES: loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type] From 53e4e43d02a9f0305a949826ca5aadd5f180bca6 Mon Sep 17 00:00:00 2001 From: eaidova Date: Tue, 25 Jun 2024 20:58:22 +0400 Subject: [PATCH 05/11] more seq2seq models --- optimum/exporters/openvino/__main__.py | 2 +- optimum/exporters/openvino/convert.py | 27 +- optimum/exporters/openvino/model_configs.py | 411 ++++++++++++++++++-- optimum/exporters/openvino/stateful.py | 17 +- optimum/intel/openvino/modeling_seq2seq.py | 14 +- 5 files changed, 401 insertions(+), 70 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index cb2019dd36..abf2e0e108 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -28,7 +28,6 @@ from openvino.runtime import Core, Type, save_model from optimum.exporters import TasksManager from optimum.exporters.onnx.base import OnnxConfig -from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.openvino.convert import export_from_model from optimum.intel.utils.import_utils import ( is_nncf_available, @@ -40,6 +39,7 @@ _infer_library_from_model_name_or_path, _OpenClipForZeroShotImageClassification, ) +from optimum.intel.utils.import_utils import is_openvino_tokenizers_available from optimum.utils.save_utils import maybe_load_preprocessors from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 4d48ff8455..c885123c51 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -35,8 +35,11 @@ from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx from optimum.exporters.utils import ( _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, + DECODER_NAME, + DECODER_WITH_PAST_NAME, + ENCODER_NAME, + _get_submodels_for_export_encoder_decoder, ) -from optimum.exporters.utils import ENCODER_NAME, DECODER_NAME, DECODER_WITH_PAST_NAME, _get_submodels_for_export_encoder_decoder from optimum.intel.utils.import_utils import ( _diffusers_version, _nncf_version, @@ -613,7 +616,6 @@ def export_from_model( task = task + "-with-past" logger.info(f"Automatic task detection to: {task}.") - stateful = stateful and ( ensure_export_task_support_stateful(task, getattr(getattr(model, "config", {}), "is_encoder_decoder", False)) or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", "")) @@ -655,13 +657,13 @@ def export_from_model( model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels ) - elif (model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS)) and stateful and not custom_architecture: + elif ( + (task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False)) + and stateful + and not custom_architecture + ): export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( - model=model, - task=task, - preprocessors=preprocessors, - 
library_name=library_name, - _variant="default" + model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default" ) stateful = [False, True] @@ -1229,9 +1231,7 @@ def _get_encoder_decoder_stateful_models_for_export( ) export_config.variant = _variant - all_variants = "\n".join( - [f" - {name}: {description}" for name, description in export_config.VARIANTS.items()] - ) + all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()]) logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}") models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True) @@ -1239,7 +1239,9 @@ def _get_encoder_decoder_stateful_models_for_export( encoder_export_config = export_config.with_behavior("encoder") models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config) - decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True, stateful=True) + decoder_export_config_with_past = export_config.with_behavior( + "decoder", use_past=True, use_past_in_inputs=True, stateful=True + ) decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME) models_for_export[DECODER_NAME] = ( decoder_with_past_model, @@ -1247,5 +1249,4 @@ def _get_encoder_decoder_stateful_models_for_export( ) logger.info(models_for_export.keys()) - return None, models_for_export diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index e8865888ba..3db6cbf1d0 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -22,7 +22,12 @@ from transformers.utils import is_tf_available -from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig +from optimum.exporters.onnx.base import ConfigBehavior +from optimum.exporters.onnx.config import ( + OnnxSeq2SeqConfigWithPast, + TextDecoderOnnxConfig, + TextDecoderWithPositionIdsOnnxConfig, +) from optimum.exporters.onnx.model_configs import ( CLIPOnnxConfig, CLIPTextOnnxConfig, @@ -36,13 +41,15 @@ IBertOnnxConfig, LlamaOnnxConfig, MistralOnnxConfig, + M2M100OnnxConfig, MPTOnnxConfig, PhiOnnxConfig, + T5OnnxConfig, UNetOnnxConfig, VisionOnnxConfig, VaeDecoderOnnxConfig, VaeEncoderOnnxConfig, - WhisperOnnxConfig + WhisperOnnxConfig, ) from optimum.exporters.onnx.model_patcher import ModelPatcher from optimum.exporters.tasks import TasksManager @@ -2211,12 +2218,35 @@ def patch_model_for_export( return super().patch_model_for_export(model, model_kwargs) -@register_in_tasks_manager("whisper", *["feature-extraction", "feature-extraction-with-past", "audio-classification", "automatic-speech-recognition", "automatic-speech-recognition-with-past",], library_name="transformers") +@register_in_tasks_manager( + "whisper", + *[ + "feature-extraction", + "feature-extraction-with-past", + "audio-classification", + "automatic-speech-recognition", + "automatic-speech-recognition-with-past", + ], + library_name="transformers", +) class WhisperOpenVINOConfig(WhisperOnnxConfig): - def __init__(self, config: PretrainedConfig, task: str = "feature-extraction", int_dtype: str = "int64", float_dtype: str = "fp32", use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, legacy: bool = False, stateful: bool = False): + def __init__( + self, + 
config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, + legacy: bool = False, + stateful: bool = False, + ): self.stateful = stateful - super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy) - + super().__init__( + config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy + ) def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: """ @@ -2232,40 +2262,335 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene return super()._create_dummy_input_generator_classes(**kwargs) def with_behavior( - self, - behavior: Union[str, ConfigBehavior], - use_past: bool = False, - use_past_in_inputs: bool = False, - stateful: bool = False - ) -> "OnnxSeq2SeqConfigWithPast": - """ - Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. - - Args: - behavior ([`ConfigBehavior`]): - The behavior to use for the new instance. - use_past (`bool`, defaults to `False`): - Whether or not the ONNX config to instantiate is for a model using KV cache. - use_past_in_inputs (`bool`, defaults to `False`): - Whether the KV cache is to be passed as an input to the ONNX. - - Returns: - `OnnxSeq2SeqConfigWithPast` - """ - if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): - behavior = ConfigBehavior(behavior) - - onnx_config = self.__class__( - self._config, - task=self.task, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - use_past=use_past, - use_past_in_inputs=use_past_in_inputs, - behavior=behavior, - preprocessors=self._preprocessors, - legacy=self.legacy, - stateful=stateful - ) - onnx_config.variant = self.variant - return onnx_config + self, + behavior: Union[str, ConfigBehavior], + use_past: bool = False, + use_past_in_inputs: bool = False, + stateful: bool = False, + ) -> "OnnxSeq2SeqConfigWithPast": + """ + Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + use_past (`bool`, defaults to `False`): + Whether or not the ONNX config to instantiate is for a model using KV cache. + use_past_in_inputs (`bool`, defaults to `False`): + Whether the KV cache is to be passed as an input to the ONNX. 
+ + Returns: + `OnnxSeq2SeqConfigWithPast` + """ + if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): + behavior = ConfigBehavior(behavior) + + onnx_config = self.__class__( + self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=self._preprocessors, + legacy=self.legacy, + stateful=stateful, + ) + onnx_config.variant = self.variant + return onnx_config + + +@register_in_tasks_manager( + "t5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class T5OpenVINOConfig(T5OnnxConfig): + def __init__( + self, + config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, + legacy: bool = False, + stateful: bool = False, + ): + self.stateful = stateful + super().__init__( + config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy + ) + + def with_behavior( + self, + behavior: Union[str, ConfigBehavior], + use_past: bool = False, + use_past_in_inputs: bool = False, + stateful: bool = False, + ) -> "OnnxSeq2SeqConfigWithPast": + """ + Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + use_past (`bool`, defaults to `False`): + Whether or not the ONNX config to instantiate is for a model using KV cache. + use_past_in_inputs (`bool`, defaults to `False`): + Whether the KV cache is to be passed as an input to the ONNX. 
+ + Returns: + `OnnxSeq2SeqConfigWithPast` + """ + if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): + behavior = ConfigBehavior(behavior) + + onnx_config = self.__class__( + self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=self._preprocessors, + legacy=self.legacy, + stateful=stateful, + ) + onnx_config.variant = self.variant + return onnx_config + + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( + self.task, self._normalized_config, **kwargs + ) + dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1]( + self.task, + self._normalized_config, + **kwargs, + ) + dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2]( + self.task, + self._normalized_config, + encoder_sequence_length=dummy_text_input_generator.sequence_length + if not self.stateful + else dummy_text_input_generator.sequence_length + 2, + **kwargs, + ) + dummy_inputs_generators = [ + dummy_text_input_generator, + dummy_decoder_text_input_generator, + dummy_seq2seq_past_key_values_generator, + ] + + return dummy_inputs_generators + + +@register_in_tasks_manager( + "mt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class MT5OpenVINOConfig(T5OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "longt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class LongT5OpenVINOConfig(T5OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "m2m-100", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class M2M100OpenVINOConfig(M2M100OnnxConfig): + def __init__( + self, + config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.MONOLITH, + preprocessors: Optional[List[Any]] = None, + legacy: bool = False, + stateful: bool = False, + ): + self.stateful = stateful + super().__init__( + config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy + ) + + def with_behavior( + self, + behavior: Union[str, ConfigBehavior], + use_past: bool = False, + use_past_in_inputs: bool = False, + stateful: bool = False, + ) -> "OnnxSeq2SeqConfigWithPast": + """ + Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. + + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + use_past (`bool`, defaults to `False`): + Whether or not the ONNX config to instantiate is for a model using KV cache. + use_past_in_inputs (`bool`, defaults to `False`): + Whether the KV cache is to be passed as an input to the ONNX. 
+ + Returns: + `OnnxSeq2SeqConfigWithPast` + """ + if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): + behavior = ConfigBehavior(behavior) + + onnx_config = self.__class__( + self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=self._preprocessors, + legacy=self.legacy, + stateful=stateful, + ) + onnx_config.variant = self.variant + return onnx_config + + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( + self.task, self._normalized_config, **kwargs + ) + task = "feature-extraction" if self.task != "text-generation" else "text-generation" + dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task]( + self.task, self._normalized_config, **kwargs + ) + if self.task != "text-generation": + kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length + if self.stateful: + kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2 + + dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task]( + self.task, self._normalized_config, **kwargs + ) + dummy_inputs_generators = [ + dummy_text_input_generator, + dummy_decoder_text_input_generator, + dummy_seq2seq_past_key_values_generator, + ] + + return dummy_inputs_generators + + +@register_in_tasks_manager( + "bart", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", + "question-answering", + ], + library_name="transformers", +) +class BartOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "mbart", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + "text-classification", + "question-answering", + ], + library_name="transformers", +) +class MBartOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "blenderbot", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + ], + library_name="transformers", +) +class BlenderbotOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "blenderbot-small", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + ], + library_name="transformers", +) +class BlenderbotSmallOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "marian", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + ], + library_name="transformers", +) +class MarianOpenVINOConfig(M2M100OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "pegasus", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text2text-generation", + "text2text-generation-with-past", + ], + library_name="transformers", +) +class PegasusOpenVINOConfig(M2M100OnnxConfig): + pass diff --git 
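# ---------------------------------------------------------------------------
# NOTE (illustrative sketch, not part of the patch): the seq2seq OpenVINO
# configs defined above add a `stateful` flag that `with_behavior()` forwards
# when splitting the monolithic config into encoder/decoder parts. A minimal
# usage sketch against the API as defined at this point in the series
# (later commits in this series refactor this and set the flag as a plain
# attribute on the decoder config instead). Assumes optimum-intel with this
# change installed and network access for the "t5-small" checkpoint.
from transformers import AutoConfig

from optimum.exporters.openvino.model_configs import T5OpenVINOConfig

config = AutoConfig.from_pretrained("t5-small")
monolith = T5OpenVINOConfig(config, task="text2text-generation", use_past=True)

# Only the decoder carries the KV-cache state; the encoder stays stateless.
encoder_config = monolith.with_behavior("encoder")
decoder_config = monolith.with_behavior("decoder", use_past=True, use_past_in_inputs=True, stateful=True)

assert decoder_config.stateful and not encoder_config.stateful
# ---------------------------------------------------------------------------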
a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 4aa0a35ab5..4482455a6c 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -191,11 +191,11 @@ def ensure_stateful_is_available(warn=True): return True -def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = False): +def ensure_export_task_support_stateful(task: str, is_encoder_decoder: bool = False): task = TasksManager.map_from_synonym(task) if not is_encoder_decoder: return task == "text-generation-with-past" - + _ENCODER_DECODER_TASKS_WITH_PAST = ( "automatic-speech-recognition", "document-question-answering", @@ -206,7 +206,7 @@ def ensure_export_task_support_stateful(task: str, is_encoder_decoder:bool = Fal is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST return is_stateful - + def remove_parameters_by_names(model: ov.Model, names: list): parameters = [model.input(name).get_node() for name in names] @@ -310,7 +310,14 @@ def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model): def patch_stateful_encoder_decoder(config, ov_model): - encoder_key_value_input_names = [key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())] + encoder_key_value_input_names = [ + key.get_any_name() + for key in ov_model.inputs + if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names()) + ] remove_parameters_by_names(ov_model, encoder_key_value_input_names) patch_stateful_decoder(config, ov_model) - insert_state_for_nodes(ov_model, find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()])) \ No newline at end of file + insert_state_for_nodes( + ov_model, + find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]), + ) diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 14880ae1ec..5012ea8697 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -36,10 +36,8 @@ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE from ...exporters.openvino.stateful import model_has_state - -from ..utils import is_transformers_version from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM -from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties +from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties, is_transformers_version if is_transformers_version(">=", "4.43.0"): @@ -350,7 +348,7 @@ def to(self, device: str): self._device = device.upper() self.encoder._device = self._device self.decoder._device = self._device - if self.use_cache and self.decoder_with_past_model is not None: + if self.use_cache and self.decoder_with_past is not None: self.decoder_with_past._device = self._device self.clear_requests() else: @@ -388,7 +386,7 @@ def forward( encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, - past_key_values = past_key_values + past_key_values=past_key_values, ) else: decoder_outputs = self.decoder_with_past( @@ -429,7 +427,7 @@ def get_encoder(self): return self.encoder def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: - 
self.decoder._reorder_cache(past, beam_idx) + self.decoder._reorder_cache(past, beam_idx) def reshape(self, batch_size: int, sequence_length: int): """ @@ -464,13 +462,13 @@ def clear_requests(self): ) self.encoder.request = None self.decoder.request = None - if self.use_cache and self.decoder_with_past_model is not None: + if self.use_cache and self.decoder_with_past is not None: self.decoder_with_past.request = None def compile(self): self.encoder._compile() self.decoder._compile() - if self.use_cache and self.decoder_with_past_model is not None: + if self.use_cache and self.decoder_with_past is not None: self.decoder_with_past._compile() From 250b00e190dfabf9be3a6b8d45000744769755c4 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 26 Jun 2024 09:41:42 +0400 Subject: [PATCH 06/11] small refactoring --- optimum/exporters/openvino/convert.py | 7 +- optimum/exporters/openvino/model_configs.py | 255 ++++++-------------- tests/openvino/test_modeling.py | 3 +- 3 files changed, 82 insertions(+), 183 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index c885123c51..a4c783312c 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -1218,7 +1218,6 @@ def _get_encoder_decoder_stateful_models_for_export( float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, ): - logger.info("HERE") export_config_constructor = TasksManager.get_exporter_config_constructor( model=model, exporter="openvino", task=task, library_name=library_name ) @@ -1239,9 +1238,9 @@ def _get_encoder_decoder_stateful_models_for_export( encoder_export_config = export_config.with_behavior("encoder") models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config) - decoder_export_config_with_past = export_config.with_behavior( - "decoder", use_past=True, use_past_in_inputs=True, stateful=True - ) + decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) + + decoder_export_config_with_past.stateful = True decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME) models_for_export[DECODER_NAME] = ( decoder_with_past_model, diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 3db6cbf1d0..8cd8946a9a 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -19,12 +19,9 @@ from packaging import version from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel from optimum.exporters.onnx.base import ConfigBehavior - from transformers.utils import is_tf_available -from optimum.exporters.onnx.base import ConfigBehavior from optimum.exporters.onnx.config import ( - OnnxSeq2SeqConfigWithPast, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig, ) @@ -44,7 +41,9 @@ M2M100OnnxConfig, MPTOnnxConfig, PhiOnnxConfig, + Pix2StructOnnxConfig, T5OnnxConfig, + TrOCROnnxConfig, UNetOnnxConfig, VisionOnnxConfig, VaeDecoderOnnxConfig, @@ -2230,24 +2229,6 @@ def patch_model_for_export( library_name="transformers", ) class WhisperOpenVINOConfig(WhisperOnnxConfig): - def __init__( - self, - config: PretrainedConfig, - task: str = "feature-extraction", - int_dtype: str = "int64", - float_dtype: str = "fp32", - use_past: bool = False, - use_past_in_inputs: bool = False, - behavior: ConfigBehavior = ConfigBehavior.MONOLITH, - preprocessors: Optional[List[Any]] = None, - legacy: bool = False, - stateful: bool = 
False, - ): - self.stateful = stateful - super().__init__( - config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy - ) - def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: """ Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`. @@ -2255,51 +2236,12 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch size. Override this method for custom behavior. """ - if self.stateful: + if getattr(self, "stateful"): if "encoder_sequence_length" not in kwargs: sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"]) kwargs["encoder_sequence_length"] = sequence_len + 2 return super()._create_dummy_input_generator_classes(**kwargs) - def with_behavior( - self, - behavior: Union[str, ConfigBehavior], - use_past: bool = False, - use_past_in_inputs: bool = False, - stateful: bool = False, - ) -> "OnnxSeq2SeqConfigWithPast": - """ - Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. - - Args: - behavior ([`ConfigBehavior`]): - The behavior to use for the new instance. - use_past (`bool`, defaults to `False`): - Whether or not the ONNX config to instantiate is for a model using KV cache. - use_past_in_inputs (`bool`, defaults to `False`): - Whether the KV cache is to be passed as an input to the ONNX. - - Returns: - `OnnxSeq2SeqConfigWithPast` - """ - if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): - behavior = ConfigBehavior(behavior) - - onnx_config = self.__class__( - self._config, - task=self.task, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - use_past=use_past, - use_past_in_inputs=use_past_in_inputs, - behavior=behavior, - preprocessors=self._preprocessors, - legacy=self.legacy, - stateful=stateful, - ) - onnx_config.variant = self.variant - return onnx_config - @register_in_tasks_manager( "t5", @@ -2307,63 +2249,6 @@ def with_behavior( library_name="transformers", ) class T5OpenVINOConfig(T5OnnxConfig): - def __init__( - self, - config: PretrainedConfig, - task: str = "feature-extraction", - int_dtype: str = "int64", - float_dtype: str = "fp32", - use_past: bool = False, - use_past_in_inputs: bool = False, - behavior: ConfigBehavior = ConfigBehavior.MONOLITH, - preprocessors: Optional[List[Any]] = None, - legacy: bool = False, - stateful: bool = False, - ): - self.stateful = stateful - super().__init__( - config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy - ) - - def with_behavior( - self, - behavior: Union[str, ConfigBehavior], - use_past: bool = False, - use_past_in_inputs: bool = False, - stateful: bool = False, - ) -> "OnnxSeq2SeqConfigWithPast": - """ - Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. - - Args: - behavior ([`ConfigBehavior`]): - The behavior to use for the new instance. - use_past (`bool`, defaults to `False`): - Whether or not the ONNX config to instantiate is for a model using KV cache. - use_past_in_inputs (`bool`, defaults to `False`): - Whether the KV cache is to be passed as an input to the ONNX. 
- - Returns: - `OnnxSeq2SeqConfigWithPast` - """ - if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): - behavior = ConfigBehavior(behavior) - - onnx_config = self.__class__( - self._config, - task=self.task, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - use_past=use_past, - use_past_in_inputs=use_past_in_inputs, - behavior=behavior, - preprocessors=self._preprocessors, - legacy=self.legacy, - stateful=stateful, - ) - onnx_config.variant = self.variant - return onnx_config - def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( self.task, self._normalized_config, **kwargs @@ -2377,7 +2262,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen self.task, self._normalized_config, encoder_sequence_length=dummy_text_input_generator.sequence_length - if not self.stateful + if not getattr(self, "stateful", False) else dummy_text_input_generator.sequence_length + 2, **kwargs, ) @@ -2414,63 +2299,6 @@ class LongT5OpenVINOConfig(T5OpenVINOConfig): library_name="transformers", ) class M2M100OpenVINOConfig(M2M100OnnxConfig): - def __init__( - self, - config: PretrainedConfig, - task: str = "feature-extraction", - int_dtype: str = "int64", - float_dtype: str = "fp32", - use_past: bool = False, - use_past_in_inputs: bool = False, - behavior: ConfigBehavior = ConfigBehavior.MONOLITH, - preprocessors: Optional[List[Any]] = None, - legacy: bool = False, - stateful: bool = False, - ): - self.stateful = stateful - super().__init__( - config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, behavior, preprocessors, legacy - ) - - def with_behavior( - self, - behavior: Union[str, ConfigBehavior], - use_past: bool = False, - use_past_in_inputs: bool = False, - stateful: bool = False, - ) -> "OnnxSeq2SeqConfigWithPast": - """ - Creates a copy of the current OnnxConfig but with a different `ConfigBehavior` and `use_past` value. - - Args: - behavior ([`ConfigBehavior`]): - The behavior to use for the new instance. - use_past (`bool`, defaults to `False`): - Whether or not the ONNX config to instantiate is for a model using KV cache. - use_past_in_inputs (`bool`, defaults to `False`): - Whether the KV cache is to be passed as an input to the ONNX. 
- - Returns: - `OnnxSeq2SeqConfigWithPast` - """ - if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): - behavior = ConfigBehavior(behavior) - - onnx_config = self.__class__( - self._config, - task=self.task, - int_dtype=self.int_dtype, - float_dtype=self.float_dtype, - use_past=use_past, - use_past_in_inputs=use_past_in_inputs, - behavior=behavior, - preprocessors=self._preprocessors, - legacy=self.legacy, - stateful=stateful, - ) - onnx_config.variant = self.variant - return onnx_config - def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( self.task, self._normalized_config, **kwargs @@ -2481,7 +2309,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen ) if self.task != "text-generation": kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length - if self.stateful: + if getattr(self, "stateful", False): kwargs["encoder_sequence_length"] = kwargs["encoder_sequence_length"] + 2 dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task]( @@ -2592,5 +2420,76 @@ class MarianOpenVINOConfig(M2M100OpenVINOConfig): ], library_name="transformers", ) -class PegasusOpenVINOConfig(M2M100OnnxConfig): +class PegasusOpenVINOConfig(M2M100OpenVINOConfig): pass + + +@register_in_tasks_manager( + "pix2struct", + *[ + "image-to-text", + "image-to-text-with-past", + ], + library_name="transformers", +) +class Pix2StructOpenVINOConfig(Pix2StructOnnxConfig): + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_inputs_generators = [] + dummy_inputs_generators.append(self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config)) + + if self._preprocessors is None or len(self._preprocessors) != 2: + raise ValueError( + f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" + ) + + encoder_sequence_length = self._preprocessors[1].image_processor.max_patches + if getattr(self, "stateful", False): + encoder_sequence_length += 2 + # A hack for DummyPix2StructInputGenerator to gain access to the preprocessors. + # TODO: we should probably pass preprocessors to all dummy input generators. 
+ kwargs["preprocessors"] = self._preprocessors + for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]: + dummy_inputs_generators.append( + cls_(self.task, self._normalized_config, encoder_sequence_length=encoder_sequence_length, **kwargs) + ) + + return dummy_inputs_generators + + +@register_in_tasks_manager( + "trocr", + *[ + "feature-extraction", + "feature-extraction-with-past", + "image-to-text", + "image-to-text-with-past", + ], + library_name="transformers", +) +class TrOCROpenVINOConfig(TrOCROnnxConfig): + def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: + dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( + self.task, self._normalized_config, **kwargs + ) + dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1]( + self.task, + self._normalized_config, + **kwargs, + ) + encoder_sequence_length = dummy_text_input_generator.sequence_length + + if getattr(self, "stateful", False): + encoder_sequence_length += 2 + dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2]( + self.task, + self._normalized_config, + encoder_sequence_length=encoder_sequence_length, + **kwargs, + ) + dummy_inputs_generators = [ + dummy_text_input_generator, + dummy_decoder_text_input_generator, + dummy_seq2seq_past_key_values_generator, + ] + + return dummy_inputs_generators diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f7f677bf8c..6302921d0d 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1525,7 +1525,8 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_model.encoder, OVEncoder) self.assertIsInstance(ov_model.decoder, OVDecoder) - self.assertIsInstance(ov_model.decoder_with_past, OVDecoder) + self.assertTrue(ov_model.decoder.stateful) + self.assertIsInstance(ov_model.decoder_with_past, None) self.assertIsInstance(ov_model.config, PretrainedConfig) transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) From 27fb9c610fde8a01c7868e0c562ee42df17e4249 Mon Sep 17 00:00:00 2001 From: eaidova Date: Wed, 26 Jun 2024 17:58:22 +0400 Subject: [PATCH 07/11] WA shapeof issue --- notebooks/ipex/text_generation.ipynb | 8 +- .../openvino/optimum_openvino_inference.ipynb | 4 +- .../openvino/quantized_generation_demo.ipynb | 86 ++++++++++--------- ...stable_diffusion_hybrid_quantization.ipynb | 25 +++--- optimum/exporters/openvino/__main__.py | 4 +- optimum/exporters/openvino/model_configs.py | 75 +--------------- optimum/exporters/openvino/stateful.py | 9 +- optimum/intel/openvino/modeling_seq2seq.py | 10 ++- tests/openvino/test_modeling.py | 2 +- 9 files changed, 78 insertions(+), 145 deletions(-) diff --git a/notebooks/ipex/text_generation.ipynb b/notebooks/ipex/text_generation.ipynb index df46355531..d1a62d9201 100644 --- a/notebooks/ipex/text_generation.ipynb +++ b/notebooks/ipex/text_generation.ipynb @@ -62,13 +62,9 @@ "source": [ "model = IPEXModelForCausalLM.from_pretrained(\"gpt2\", torch_dtype=torch.bfloat16, export=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", - "input_sentence = [\n", - " \"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?\"\n", - "]\n", + "input_sentence = [\"Answer the following yes/no question by reasoning step-by-step please. 
Can you write a whole Haiku in a single tweet?\"]\n", "model_inputs = tokenizer(input_sentence, return_tensors=\"pt\")\n", - "generation_kwargs = dict(\n", - " max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True\n", - ")\n", + "generation_kwargs = dict(max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True)\n", "\n", "generated_ids = model.generate(**model_inputs, **generation_kwargs)\n", "output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n", diff --git a/notebooks/openvino/optimum_openvino_inference.ipynb b/notebooks/openvino/optimum_openvino_inference.ipynb index 7ef14e0635..76c77aec55 100644 --- a/notebooks/openvino/optimum_openvino_inference.ipynb +++ b/notebooks/openvino/optimum_openvino_inference.ipynb @@ -466,9 +466,7 @@ "source": [ "# Set the device directly with `.from_pretrained()`\n", "if \"GPU\" in Core().available_devices:\n", - " model = OVModelForQuestionAnswering.from_pretrained(\n", - " \"distilbert-base-uncased-distilled-squad-ov-fp16\", device=\"GPU\"\n", - " )" + " model = OVModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased-distilled-squad-ov-fp16\", device=\"GPU\")" ] }, { diff --git a/notebooks/openvino/quantized_generation_demo.ipynb b/notebooks/openvino/quantized_generation_demo.ipynb index cc5c1ec2b3..5673243cb2 100644 --- a/notebooks/openvino/quantized_generation_demo.ipynb +++ b/notebooks/openvino/quantized_generation_demo.ipynb @@ -121,7 +121,7 @@ " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n", " },\n", " \"compile\": False,\n", - " \"quantization_config\": quantization_config,\n", + " \"quantization_config\": quantization_config\n", "}\n", "\n", "# Check whether the model was already exported\n", @@ -143,8 +143,8 @@ "\n", "# TODO Optional: export to huggingface/hub\n", "\n", - "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024**3\n", - "print(f\"Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB\")" + "model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024 ** 3\n", + "print(f'Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB')" ] }, { @@ -212,7 +212,7 @@ "from transformers import TextStreamer\n", "\n", "# Tokenize the sample\n", - "inputs = tokenizer([sample], return_tensors=\"pt\")\n", + "inputs = tokenizer([sample], return_tensors='pt')\n", "\n", "# Call generate on the inputs\n", "out = model.generate(\n", @@ -294,7 +294,7 @@ "\n", "\n", "# Tokenize the sample\n", - "inputs = tokenizer([sample], return_tensors=\"pt\")\n", + "inputs = tokenizer([sample], return_tensors='pt') \n", "\n", "out = stateless_model.generate(\n", " **inputs,\n", @@ -302,7 +302,7 @@ " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", " pad_token_id=tokenizer.eos_token_id,\n", " prompt_lookup_num_tokens=3,\n", - ")" + ") " ] }, { @@ -358,7 +358,7 @@ " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n", " },\n", " \"compile\": False,\n", - " \"quantization_config\": quantization_config,\n", + " \"quantization_config\": quantization_config\n", "}\n", "\n", "# Check whether the model was already exported\n", @@ -458,15 +458,15 @@ " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n", " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n", " self.model_forward = 
self.model.forward\n", - "\n", + " \n", " @wraps(self.model_forward)\n", " def forward_wrapper(**kwargs):\n", " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n", " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n", " return self.model_forward(**kwargs)\n", - "\n", + " \n", " self.model.forward = forward_wrapper\n", - "\n", + " \n", " # wrap generate method\n", " self.model_generate = self.model.generate\n", "\n", @@ -479,11 +479,10 @@ " out = self.model_generate(*args, **kwargs)\n", " self.seq_lens[-1].append(out.shape[-1])\n", " return out\n", - "\n", " self.model.generate = generate_wrapper\n", " return self\n", "\n", - " def __exit__(self, type, value, traceback):\n", + " def __exit__(self, type, value, traceback):\n", " self.model.forward = self.model_forward\n", " self.model.generate = self.model_generate\n", " self.model_forward = None\n", @@ -495,7 +494,7 @@ " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n", " # Add window size for output to ease calculation later\n", " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", - " ws.append(0)\n", + " ws.append(0) \n", "\n", " def acceptance_rate(self, return_mean=True, normalize=False):\n", " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n", @@ -504,8 +503,9 @@ " sl = np.array(sl, dtype=np.float64)\n", " ws = np.array(ws, dtype=np.float64)\n", " out_lens = sl - ws\n", - " accepted = out_lens[1:] - out_lens[:-1] - 1\n", - " ar_per_win.append(np.divide(accepted, ws[:-1], out=np.zeros_like(accepted), where=ws[:-1] != 0))\n", + " accepted = (out_lens[1:] - out_lens[:-1] - 1)\n", + " ar_per_win.append(np.divide(accepted, ws[:-1],\n", + " out=np.zeros_like(accepted),where=ws[:-1] != 0))\n", " ar_per_win = np.hstack(ar_per_win)\n", " # Normalized AR doesn't take into account windows with size 0\n", " if normalize:\n", @@ -544,7 +544,7 @@ "samples_number = 30\n", "with AcceptanceRateRecorder(stateless_model) as ar_recorder:\n", " for text in tqdm(dataset[:samples_number]):\n", - " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors=\"pt\")\n", + " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors='pt')\n", " stateless_model.generate(\n", " **tokenized_prompt,\n", " max_new_tokens=128,\n", @@ -623,6 +623,7 @@ " return False\n", "\n", "\n", + "\n", "# Set the chat template to the tokenizer. 
The chat template implements the simple template of\n", "# User: content\n", "# Assistant: content\n", @@ -650,7 +651,11 @@ " if model_msg:\n", " messages.append({\"role\": \"Assistant\", \"content\": model_msg})\n", " input_token = tokenizer.apply_chat_template(\n", - " messages, add_generation_prompt=True, tokenize=True, return_tensors=\"pt\", return_dict=True\n", + " messages,\n", + " add_generation_prompt=True,\n", + " tokenize=True,\n", + " return_tensors=\"pt\",\n", + " return_dict=True\n", " )\n", " return input_token\n", "\n", @@ -674,18 +679,18 @@ " # Construct the input message string for the model by concatenating the current system message and conversation history\n", " # Tokenize the messages string\n", " inputs = prepare_history_for_model(history)\n", - " input_length = inputs[\"input_ids\"].shape[1]\n", + " input_length = inputs['input_ids'].shape[1]\n", " # truncate input in case it is too long.\n", " # TODO improve this\n", " if input_length > 2000:\n", " history = [history[-1]]\n", " inputs = prepare_history_for_model(history)\n", - " input_length = inputs[\"input_ids\"].shape[1]\n", + " input_length = inputs['input_ids'].shape[1]\n", "\n", " prompt_char = \"▌\"\n", " history[-1][1] = prompt_char\n", " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", - "\n", + " \n", " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", "\n", " # Create a stopping criteria to prevent the model from playing the role of the user aswell.\n", @@ -701,14 +706,11 @@ " eos_token_id=[tokenizer.eos_token_id],\n", " pad_token_id=tokenizer.eos_token_id,\n", " )\n", - " generate_kwargs = (\n", - " dict(\n", - " streamer=streamer,\n", - " generation_config=generation_config,\n", - " stopping_criteria=stopping_criteria,\n", - " )\n", - " | inputs\n", - " )\n", + " generate_kwargs = dict(\n", + " streamer=streamer,\n", + " generation_config=generation_config,\n", + " stopping_criteria=stopping_criteria,\n", + " ) | inputs\n", "\n", " if assisted:\n", " target_generate = stateless_model.generate\n", @@ -735,7 +737,7 @@ " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", " history[-1][1] = partial_text\n", " generation_time = time.perf_counter() - start\n", - " yield history, f\"Generation time: {generation_time:.2f} sec\", *([gr.update(interactive=True)] * 4)" + " yield history, f'Generation time: {generation_time:.2f} sec', *([gr.update(interactive=True)] * 4)" ] }, { @@ -779,9 +781,7 @@ " [\"Can you explain to me briefly what is Python programming language?\"],\n", " [\"Explain the plot of Cinderella in a sentence.\"],\n", " [\"Write a Python function to perform binary search over a sorted list. Use markdown to write code\"],\n", - " [\n", - " \"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?\"\n", - " ],\n", + " [\"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. 
How long will it take for the ball to reach the ground?\"],\n", "]\n", "\n", "\n", @@ -797,7 +797,7 @@ " \"\"\"\n", " # Append current user message to history with a blank assistant message which will be generated by the model\n", " history.append([message, None])\n", - " return (\"\", history)\n", + " return ('', history)\n", "\n", "\n", "def prepare_for_regenerate(history):\n", @@ -808,7 +808,7 @@ " history: conversation history\n", " Returns:\n", " updated history\n", - " \"\"\"\n", + " \"\"\" \n", " history[-1][1] = None\n", " return history\n", "\n", @@ -821,7 +821,7 @@ " msg = gr.Textbox(placeholder=\"Enter message here...\", show_label=False, autofocus=True, scale=75)\n", " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=15)\n", " with gr.Row():\n", - " submit = gr.Button(\"Submit\", variant=\"primary\")\n", + " submit = gr.Button(\"Submit\", variant='primary')\n", " regenerate = gr.Button(\"Regenerate\")\n", " clear = gr.Button(\"Clear\")\n", " with gr.Accordion(\"Advanced Options:\", open=False):\n", @@ -860,7 +860,9 @@ " step=0.1,\n", " interactive=True,\n", " )\n", - " gr.Examples(EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\")\n", + " gr.Examples(\n", + " EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\"\n", + " )\n", "\n", " # Sets generate function to be triggered when the user submit a new message\n", " gr.on(\n", @@ -874,14 +876,20 @@ " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", " concurrency_limit=1,\n", - " queue=True,\n", + " queue=True\n", " )\n", - " regenerate.click(fn=prepare_for_regenerate, inputs=chatbot, outputs=chatbot, queue=True, concurrency_limit=1).then(\n", + " regenerate.click(\n", + " fn=prepare_for_regenerate,\n", + " inputs=chatbot,\n", + " outputs=chatbot,\n", + " queue=True,\n", + " concurrency_limit=1\n", + " ).then(\n", " fn=generate,\n", " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", " concurrency_limit=1,\n", - " queue=True,\n", + " queue=True\n", " )\n", " clear.click(fn=lambda: (None, \"Status: Idle\"), inputs=None, outputs=[chatbot, status], queue=False)" ] diff --git a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb index d89457bd78..8ef2e8ad6c 100644 --- a/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb +++ b/notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb @@ -167,7 +167,6 @@ "def preprocess_fn(example):\n", " return {\"prompt\": example[\"caption\"]}\n", "\n", - "\n", "NUM_SAMPLES = 200\n", "dataset = dataset.take(NUM_SAMPLES)\n", "calibration_dataset = dataset.map(lambda x: preprocess_fn(x), remove_columns=dataset.column_names)" @@ -1067,14 +1066,12 @@ ], "source": [ "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)\n", - "quantization_config = OVWeightQuantizationConfig(\n", - " bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID\n", - ")\n", + "quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)\n", "quantizer = OVQuantizer(int8_pipe)\n", "quantizer.quantize(\n", " ov_config=OVConfig(quantization_config=quantization_config),\n", " calibration_dataset=calibration_dataset,\n", - " 
save_directory=int8_model_path,\n", + " save_directory=int8_model_path\n", ")" ] }, @@ -1205,10 +1202,8 @@ " im_w, im_h = fp32_img.size\n", " is_horizontal = im_h <= im_w\n", " figsize = (20, 30) if is_horizontal else (30, 20)\n", - " fig, axs = plt.subplots(\n", - " 1 if is_horizontal else 2, 2 if is_horizontal else 1, figsize=figsize, sharex=\"all\", sharey=\"all\"\n", - " )\n", - " fig.patch.set_facecolor(\"white\")\n", + " fig, axs = plt.subplots(1 if is_horizontal else 2, 2 if is_horizontal else 1, figsize=figsize, sharex='all', sharey='all')\n", + " fig.patch.set_facecolor('white')\n", " list_axes = list(axs.flat)\n", " for a in list_axes:\n", " a.set_xticklabels([])\n", @@ -1222,7 +1217,7 @@ " img2_title = \"INT8 result\"\n", " list_axes[0].set_title(img1_title, fontsize=20)\n", " list_axes[1].set_title(img2_title, fontsize=20)\n", - " fig.subplots_adjust(wspace=0.0 if is_horizontal else 0.01, hspace=0.01 if is_horizontal else 0.0)\n", + " fig.subplots_adjust(wspace=0.0 if is_horizontal else 0.01 , hspace=0.01 if is_horizontal else 0.0)\n", " fig.tight_layout()" ] }, @@ -1235,10 +1230,13 @@ "source": [ "prompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\n", "\n", - "\n", "def generate_image(pipeline, prompt):\n", " transformers.set_seed(1)\n", - " return pipeline(prompt=prompt, guidance_scale=8.0, output_type=\"pil\").images[0]" + " return pipeline(\n", + " prompt=prompt,\n", + " guidance_scale=8.0,\n", + " output_type=\"pil\"\n", + " ).images[0]" ] }, { @@ -1331,7 +1329,7 @@ "def get_model_size(model_folder, framework):\n", " \"\"\"\n", " Return OpenVINO or PyTorch model size in Mb.\n", - "\n", + " \n", " Arguments:\n", " model_folder:\n", " Directory containing a model.\n", @@ -1533,7 +1531,6 @@ "def get_val_dataset(num_items=3):\n", " return [item[\"caption\"] for item in dataset.take(num_items)]\n", "\n", - "\n", "def benchmark(pipeline, dataset):\n", " \"\"\"\n", " Benchmark PyTorch or OpenVINO model. 
This function does inference on `num_items`\n", diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index abf2e0e108..e1fe1d4017 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -27,6 +27,7 @@ from openvino.runtime import Core, Type, save_model from optimum.exporters import TasksManager +from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.onnx.base import OnnxConfig from optimum.exporters.openvino.convert import export_from_model from optimum.intel.utils.import_utils import ( @@ -39,7 +40,6 @@ _infer_library_from_model_name_or_path, _OpenClipForZeroShotImageClassification, ) -from optimum.intel.utils.import_utils import is_openvino_tokenizers_available from optimum.utils.save_utils import maybe_load_preprocessors from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry @@ -267,6 +267,8 @@ def main_export( # some models force flash_attn attention by default that does not support load model on cpu if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES: loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type] + if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: + loading_kwargs["attn_implementation"] = "eager" # there are some difference between remote and in library representation of past key values for some models, # for avoiding confusion we disable remote code for them if ( diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 8cd8946a9a..1993bf08c1 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -41,9 +41,7 @@ M2M100OnnxConfig, MPTOnnxConfig, PhiOnnxConfig, - Pix2StructOnnxConfig, T5OnnxConfig, - TrOCROnnxConfig, UNetOnnxConfig, VisionOnnxConfig, VaeDecoderOnnxConfig, @@ -2236,7 +2234,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch size. Override this method for custom behavior. """ - if getattr(self, "stateful"): + if getattr(self, "stateful", False): if "encoder_sequence_length" not in kwargs: sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"]) kwargs["encoder_sequence_length"] = sequence_len + 2 @@ -2422,74 +2420,3 @@ class MarianOpenVINOConfig(M2M100OpenVINOConfig): ) class PegasusOpenVINOConfig(M2M100OpenVINOConfig): pass - - -@register_in_tasks_manager( - "pix2struct", - *[ - "image-to-text", - "image-to-text-with-past", - ], - library_name="transformers", -) -class Pix2StructOpenVINOConfig(Pix2StructOnnxConfig): - def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: - dummy_inputs_generators = [] - dummy_inputs_generators.append(self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config)) - - if self._preprocessors is None or len(self._preprocessors) != 2: - raise ValueError( - f"Preprocessors for pix2struct need to be available for the ONNX export to infer input static shapes. Got: {self._preprocessors}" - ) - - encoder_sequence_length = self._preprocessors[1].image_processor.max_patches - if getattr(self, "stateful", False): - encoder_sequence_length += 2 - # A hack for DummyPix2StructInputGenerator to gain access to the preprocessors. 
- # TODO: we should probably pass preprocessors to all dummy input generators. - kwargs["preprocessors"] = self._preprocessors - for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]: - dummy_inputs_generators.append( - cls_(self.task, self._normalized_config, encoder_sequence_length=encoder_sequence_length, **kwargs) - ) - - return dummy_inputs_generators - - -@register_in_tasks_manager( - "trocr", - *[ - "feature-extraction", - "feature-extraction-with-past", - "image-to-text", - "image-to-text-with-past", - ], - library_name="transformers", -) -class TrOCROpenVINOConfig(TrOCROnnxConfig): - def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: - dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0]( - self.task, self._normalized_config, **kwargs - ) - dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1]( - self.task, - self._normalized_config, - **kwargs, - ) - encoder_sequence_length = dummy_text_input_generator.sequence_length - - if getattr(self, "stateful", False): - encoder_sequence_length += 2 - dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2]( - self.task, - self._normalized_config, - encoder_sequence_length=encoder_sequence_length, - **kwargs, - ) - dummy_inputs_generators = [ - dummy_text_input_generator, - dummy_decoder_text_input_generator, - dummy_seq2seq_past_key_values_generator, - ] - - return dummy_inputs_generators diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 4482455a6c..8377742657 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -198,10 +198,7 @@ def ensure_export_task_support_stateful(task: str, is_encoder_decoder: bool = Fa _ENCODER_DECODER_TASKS_WITH_PAST = ( "automatic-speech-recognition", - "document-question-answering", - "image-to-text", "text2text-generation", - "visual-question-answering", ) is_stateful = task.endswith("-with-past") and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST @@ -232,6 +229,10 @@ def get_read_value_ops(model: ov.Model): return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"] +def get_shape_of_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"] + + def get_consumer_nodes(node): consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()]) return set(input.get_node() for input in consumer_inputs) @@ -239,7 +240,7 @@ def get_consumer_nodes(node): def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list): # Search for nodes in the model graph that depend on nodes in `starts` list but independent of other model Parameter's/ReadValue's - other_inputs = set(model.get_parameters() + get_read_value_ops(model)) - set(sources) + other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources) other_nodes = find_dependent_nodes(model, other_inputs) source_dependent_nodes = find_dependent_nodes(model, sources) # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 5012ea8697..0262a25818 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -657,7 +657,7 @@ def forward( if "beam_idx" in self.input_names: batch_size = input_ids.shape[0] inputs["beam_idx"] = ( 
- self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) + self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=np.int32) ) # Run inference self.request.start_async(inputs, share_inputs=True) @@ -818,7 +818,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model @@ -901,7 +903,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 6302921d0d..c7a70f8139 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -1526,7 +1526,7 @@ def test_compare_to_transformers(self, model_arch): self.assertIsInstance(ov_model.encoder, OVEncoder) self.assertIsInstance(ov_model.decoder, OVDecoder) self.assertTrue(ov_model.decoder.stateful) - self.assertIsInstance(ov_model.decoder_with_past, None) + self.assertTrue(ov_model.decoder_with_past is None) self.assertIsInstance(ov_model.config, PretrainedConfig) transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) From 1e28c29c6e9309a1a8947a58b6c5258936e5b03f Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 4 Jul 2024 11:10:24 +0400 Subject: [PATCH 08/11] fix loading non-stateful --- optimum/exporters/openvino/convert.py | 5 +++-- optimum/intel/openvino/modeling_seq2seq.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index a4c783312c..405a311c10 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -651,15 +651,16 @@ def export_from_model( logging.disable(logging.INFO) + if library_name == "open_clip": custom_architecture = True custom_export_configs, fn_get_submodels = _get_open_clip_submodels_fn_and_export_configs( model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels ) - elif ( + elif ( + stateful and (task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False)) - and stateful and not custom_architecture ): export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 0262a25818..3e4f7180aa 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -327,7 +327,7 @@ def __init__( self.encoder = OVEncoder(self.encoder_model, parent_model=self) self.decoder = OVDecoder(self.decoder_model, parent_model=self) - if self.use_cache and self.decoder_with_past is not None: + if self.use_cache and self.decoder_with_past_model is not None: self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self) if enable_compilation: self.compile() From 
77667e1093a3d91c41c06f453b9a03241467f947 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 4 Jul 2024 11:24:09 +0400 Subject: [PATCH 09/11] fix loading decoder with past conditions --- optimum/exporters/openvino/__main__.py | 2 +- optimum/exporters/openvino/convert.py | 9 ++--- .../intel/openvino/modeling_base_seq2seq.py | 34 ++++++++++++++++--- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index e1fe1d4017..973c75926b 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -27,8 +27,8 @@ from openvino.runtime import Core, Type, save_model from optimum.exporters import TasksManager -from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.onnx.base import OnnxConfig +from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED from optimum.exporters.openvino.convert import export_from_model from optimum.intel.utils.import_utils import ( is_nncf_available, diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 405a311c10..b05c372172 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -651,16 +651,17 @@ def export_from_model( logging.disable(logging.INFO) - if library_name == "open_clip": custom_architecture = True custom_export_configs, fn_get_submodels = _get_open_clip_submodels_fn_and_export_configs( model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels ) - elif ( - stateful and - (task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False)) + elif ( + stateful + and ( + task.startswith(TasksManager._ENCODER_DECODER_TASKS) and getattr(model.config, "is_encoder_decoder", False) + ) and not custom_architecture ): export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 08b44e258e..6f7933d50e 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -234,8 +234,6 @@ def _from_pretrained( # Load model from hub else: model_file_names = {"encoder": encoder_file_name, "decoder": decoder_file_name} - if use_cache: - model_file_names["decoder_with_past"] = decoder_with_past_file_name # If not ONNX then OpenVINO IR : adds binary files if not from_onnx: @@ -259,7 +257,21 @@ def _from_pretrained( if not compile_only: encoder = cls.load_model(file_names["encoder"], quantization_config) decoder = cls.load_model(file_names["decoder"], quantization_config) - if use_cache: + if use_cache and not model_has_state(decoder): + model_file_names["decoder_with_past"] = decoder_with_past_file_name + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in ["decoder_with_past", "decoder_with_past_bin"]: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config) else: encoder = cls._compile_model( @@ -268,7 +280,21 @@ def _from_pretrained( decoder = cls._compile_model( 
file_names["decoder"], kwargs.get("device", "CPU"), kwargs.get("ov_config"), model_save_dir ) - if use_cache: + if use_cache and not model_has_state(decoder): + model_file_names["decoder_with_past"] = decoder_with_past_file_name + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in ["decoder_with_past", "decoder_with_past_bin"]: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls._compile_model( file_names["decoder_with_past"], kwargs.get("device", "CPU"), From 429784f8998e277e168a9a27ca0b6f8469ab64f7 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 4 Jul 2024 11:28:55 +0400 Subject: [PATCH 10/11] update quantization tests --- optimum/intel/openvino/modeling_base_seq2seq.py | 4 +++- optimum/intel/openvino/modeling_seq2seq.py | 4 ++-- tests/openvino/test_exporters_cli.py | 2 +- tests/openvino/test_quantization.py | 11 ++++++----- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 6f7933d50e..73e39391cf 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -319,7 +319,6 @@ def _from_pretrained( logger.info( "Generation config file not found, using a generation config created from the model config." ) - return cls( encoder=encoder, decoder=decoder, @@ -393,6 +392,8 @@ def _from_transformers( ov_config = None else: ov_config = OVConfig(dtype="fp32") + + stateful = kwargs.get("stateful", True) main_export( model_name_or_path=model_id, @@ -406,6 +407,7 @@ def _from_transformers( force_download=force_download, trust_remote_code=trust_remote_code, ov_config=ov_config, + stateful=stateful ) return cls._from_pretrained( diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 3e4f7180aa..c36364fbf0 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -380,7 +380,7 @@ def forward( encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) # Decode - if past_key_values is None or self.decoder_with_past is None: + if past_key_values is None or self.decoder.stateful: decoder_outputs = self.decoder( input_ids=decoder_input_ids[:, -1:] if past_key_values is not None else decoder_input_ids, encoder_hidden_states=encoder_outputs.last_hidden_state, @@ -481,7 +481,7 @@ class OVEncoder: The OpenVINO inference request associated to the encoder. 
""" - def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM, merged=False): + def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM): self.model = model self.parent_model = parent_model self._comple_only = parent_model._compile_only diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 67511bb845..d3f2d5b4ff 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -261,7 +261,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str): if task.startswith("text2text-generation"): models = [model.encoder, model.decoder] - if task.endswith("with-past"): + if task.endswith("with-past") and not model.decoder.stateful: models.append(model.decoder_with_past) elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"): models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder] diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 2869acf834..f74d949188 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -566,7 +566,7 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust self.assertEqual(model._openvino_config.dtype, "int8") if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder, model.decoder_with_past] + models = [model.encoder, model.decoder] elif model.export_feature == "text-to-image": models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) @@ -706,8 +706,8 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type, tru MODEL_NAMES[model_type], export=True, load_in_8bit=False, trust_remote_code=trust_remote_code ) if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder, model.decoder_with_past] - elif model.export_feature == "text-to-image": + models = [model.encoder, model.decoder] + elif model.export_feature.startswith("text-to-image"): models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) elif model_type == "open-clip": @@ -1126,11 +1126,12 @@ def _generate_random_audio_data(processor): @parameterized.expand(itertools.product(MODEL_ID, APPLY_CACHING)) def test_calibration_data_uniqueness(self, model_id, apply_caching): ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True) + self.assertTrue(ov_model.decoder_with_past is None) processor = AutoProcessor.from_pretrained(model_id) calibration_data = [] - ov_model.decoder_with_past.request = InferRequestWrapper( - ov_model.decoder_with_past.request, calibration_data, apply_caching=apply_caching + ov_model.decoder.request = InferRequestWrapper( + ov_model.decoder.request, calibration_data, apply_caching=apply_caching ) for _ in range(2): input_features = self._generate_random_audio_data(processor) From fbd7f9e961cac0c3c2f7e94a2494d8140780ad6a Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 22 Nov 2024 21:44:54 +0400 Subject: [PATCH 11/11] update whisper on the latest transformers --- optimum/exporters/openvino/__main__.py | 4 +- optimum/exporters/openvino/convert.py | 6 +- optimum/exporters/openvino/model_configs.py | 21 ++--- optimum/exporters/openvino/model_patcher.py | 55 +++++++++++- 
 optimum/exporters/openvino/stateful.py        |  2 +
 .../intel/openvino/modeling_base_seq2seq.py   | 40 ++++-----
 optimum/intel/openvino/modeling_seq2seq.py    | 86 +++++++++++++++++--
 7 files changed, 168 insertions(+), 46 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 973c75926b..639286988b 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -267,8 +267,8 @@ def main_export(
         # some models force flash_attn attention by default that does not support load model on cpu
         if is_transformers_version(">=", "4.36") and model_type in FORCE_ATTN_MODEL_CLASSES:
             loading_kwargs["_attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
-        if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
-            loading_kwargs["attn_implementation"] = "eager"
+        # if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
+        #     loading_kwargs["attn_implementation"] = "eager"
         # there are some difference between remote and in library representation of past key values for some models,
         # for avoiding confusion we disable remote code for them
         if (
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index b05c372172..38bebc4f8b 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -648,7 +648,7 @@ def export_from_model(
             input_shapes[input_name] = (
                 kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
             )
-
+
     logging.disable(logging.INFO)
 
     if library_name == "open_clip":
@@ -667,7 +667,7 @@ def export_from_model(
         export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
             model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
         )
-        stateful = [False, True]
+        stateful_submodels = [False, True]
     elif library_name == "diffusers":
         export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
 
@@ -911,7 +911,6 @@ def _add_version_info_to_model(model: Model, library_name: Optional[str] = None)
     return model
 
 
-
 def _get_multi_modal_submodels_and_export_configs(
     model: Union["PreTrainedModel", "TFPreTrainedModel"],
     task: str,
@@ -950,6 +949,7 @@ def _get_multi_modal_submodels_and_export_configs(
         stateful_parts.append(stateful if getattr(model_part_config, "use_past", False) else False)
     return main_config, models_for_export, stateful_parts
 
+
 def _get_submodels_and_export_configs(
     model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"],
     task: str,
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 1993bf08c1..5427f92839 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -48,6 +48,7 @@
     VaeEncoderOnnxConfig,
     WhisperOnnxConfig,
 )
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator
@@ -99,6 +100,7 @@
     QwenModelPatcher,
     RotaryEmbPatcher,
     UpdateCausalMaskModelPatcher,
+    WhisperStatefulDecoderPatcher,
     XverseModelPatcher,
 )
 
@@ -2227,18 +2229,13 @@ def patch_model_for_export(
     library_name="transformers",
 )
 class WhisperOpenVINOConfig(WhisperOnnxConfig):
-    def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
-        """
-        Instantiates the dummy input generators from `self.DUMMY_INPUT_GENERATOR_CLASSES`.
-        Each dummy input generator is independent, so this method instantiates the first generator, and
-        forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch
-        size. Override this method for custom behavior.
-        """
-        if getattr(self, "stateful", False):
-            if "encoder_sequence_length" not in kwargs:
-                sequence_len = kwargs.get("sequence_length", DEFAULT_DUMMY_SHAPES["sequence_length"])
-                kwargs["encoder_sequence_length"] = sequence_len + 2
-        return super()._create_dummy_input_generator_classes(**kwargs)
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            print("HERE")
+            return WhisperStatefulDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
 
 
 @register_in_tasks_manager(
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 58659e637b..88b417a3f1 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -24,7 +24,12 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
+from optimum.exporters.onnx.model_patcher import (
+    DecoderModelPatcher,
+    ModelPatcher,
+    override_arguments,
+    Seq2SeqModelPatcher,
+)
 from optimum.intel.utils.import_utils import (
     _openvino_version,
     _torch_version,
@@ -3237,3 +3242,51 @@ def __init__(
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+class WhisperStatefulDecoderPatcher(Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        @functools.wraps(model.__orig_forward)
+        def patched_forward(*args, **kwargs):
+            from transformers.cache_utils import EncoderDecoderCache
+
+            print("HERE!!!")
+
+            signature = inspect.signature(self.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+            return_legacy_cache = False
+            pkv_in_args = False
+            legacy_pkv = None
+            if "past_key_values" in kwargs:
+                legacy_pkv = kwargs.pop("past_key_values", None)
+            sign_names = list(signature.parameters.keys())
+            pkv_argument_index = sign_names.index("past_key_values")
+            if legacy_pkv is None and len(args) > pkv_argument_index:
+                legacy_pkv = args[pkv_argument_index]
+                pkv_in_args = True
+            if legacy_pkv is not None:
+                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
+                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
+                return_legacy_cache = True
+                if not pkv_in_args:
+                    kwargs["past_key_values"] = pkv
+                else:
+                    args[pkv_argument_index] = pkv
+
+            outputs = model.__orig_forward(*args, **kwargs)
+            if return_legacy_cache:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
+            return outputs
+
+        model.forward = patched_forward
+
+        super().__init__(config, model, model_kwargs)
diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py
index 8377742657..cf58a3693e 100644
--- a/optimum/exporters/openvino/stateful.py
+++ b/optimum/exporters/openvino/stateful.py
@@ -311,11 +311,13 @@ def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
 
 
 def patch_stateful_encoder_decoder(config, ov_model):
+    log.warn(ov_model)
     encoder_key_value_input_names = [
         key.get_any_name()
         for key in ov_model.inputs
         if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
     ]
+    log.warn(encoder_key_value_input_names)
     remove_parameters_by_names(ov_model, encoder_key_value_input_names)
     patch_stateful_decoder(config, ov_model)
     insert_state_for_nodes(
diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py
index 73e39391cf..d16c533441 100644
--- a/optimum/intel/openvino/modeling_base_seq2seq.py
+++ b/optimum/intel/openvino/modeling_base_seq2seq.py
@@ -262,15 +262,15 @@ def _from_pretrained(
                     model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
                     for name in ["decoder_with_past", "decoder_with_past_bin"]:
                         model_cache_path = hf_hub_download(
-                                repo_id=model_id,
-                                filename=model_file_names[name],
-                                token=token,
-                                revision=revision,
-                                cache_dir=cache_dir,
-                                force_download=force_download,
-                                local_files_only=local_files_only,
-                                subfolder=subfolder,
-                            )
+                            repo_id=model_id,
+                            filename=model_file_names[name],
+                            token=token,
+                            revision=revision,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            local_files_only=local_files_only,
+                            subfolder=subfolder,
+                        )
                         file_names[name] = model_cache_path
                     decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)
             else:
@@ -285,15 +285,15 @@ def _from_pretrained(
                     model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin")
                     for name in ["decoder_with_past", "decoder_with_past_bin"]:
                         model_cache_path = hf_hub_download(
-                                repo_id=model_id,
-                                filename=model_file_names[name],
-                                token=token,
-                                revision=revision,
-                                cache_dir=cache_dir,
-                                force_download=force_download,
-                                local_files_only=local_files_only,
-                                subfolder=subfolder,
-                            )
+                            repo_id=model_id,
+                            filename=model_file_names[name],
+                            token=token,
+                            revision=revision,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            local_files_only=local_files_only,
+                            subfolder=subfolder,
+                        )
                         file_names[name] = model_cache_path
                     decoder_with_past = cls._compile_model(
                         file_names["decoder_with_past"],
@@ -392,7 +392,7 @@ def _from_transformers(
             ov_config = None
         else:
             ov_config = OVConfig(dtype="fp32")
-
+
         stateful = kwargs.get("stateful", True)
 
         main_export(
@@ -407,7 +407,7 @@ def _from_transformers(
             force_download=force_download,
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
-            stateful=stateful
+            stateful=stateful,
         )
 
         return cls._from_pretrained(
diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index c36364fbf0..c4d9621b1b 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -37,7 +37,8 @@
 from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
 from ...exporters.openvino.stateful import model_has_state
 from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
-from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties, is_transformers_version
+from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties
+from ..utils import is_transformers_version
 
 
 if is_transformers_version(">=", "4.43.0"):
@@ -387,6 +388,7 @@ def forward(
                 encoder_attention_mask=attention_mask,
                 decoder_attention_mask=decoder_attention_mask,
                 past_key_values=past_key_values,
+                cache_position=cache_position,
             )
         else:
             decoder_outputs = self.decoder_with_past(
@@ -628,6 +630,7 @@ def forward(
 
         if self.stateful and past_key_values is None:
             self.request.reset_state()
+            self._past_len = 0
 
         if past_key_values is not None and not self.stateful:
             # Flatten the past_key_values
@@ -651,7 +654,9 @@ def forward(
         if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None:
             inputs["decoder_attention_mask"] = decoder_attention_mask
 
-        if "cache_position" in self.input_names and cache_position is not None:
+        if "cache_position" in self.input_names:
+            if cache_position is None:
+                cache_position = torch.arange(self._past_len, self._past_len + input_ids.shape[1])
             inputs["cache_position"] = cache_position
 
         if "beam_idx" in self.input_names:
@@ -664,9 +669,13 @@ def forward(
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
 
-        # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
-        # self-attention layer and 2 to the cross-attention layer)
-        out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+        self._past_len += input_ids.shape[1]
+
+        out_past_key_values = ()
+        if not self.stateful:
+            # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
+            # self-attention layer and 2 to the cross-attention layer)
+            out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
 
         # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
         # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
@@ -1002,9 +1011,6 @@ class _OVModelForWhisper(OVModelForSpeechSeq2Seq, WhisperForConditionalGeneratio
     """
 
     auto_model_class = WhisperForConditionalGeneration
-
-    # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods
-    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
     generate = WhisperForConditionalGeneration.generate
 
     @classmethod
@@ -1032,3 +1038,67 @@ def __init__(self, stride):
     # a dummy model attribute that's used in the generate method to compute the input stride
     # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
     model = DummyWhisperModel()
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        use_cache=None,
+        encoder_outputs=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time
+        # this function needs to be touched, let's try to sort out the commonalities between the two and remove the
+        # overwrite.
+
+        decoder_position_ids = None
+        if decoder_attention_mask is not None:
+            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
+
+        past_length = 0
+        if past_key_values is not None:
+            if self.decoder.stateful:
+                past_length = getattr(self.decoder, "_past_len", 0)
+            else:
+                if isinstance(past_key_values, EncoderDecoderCache):
+                    past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                else:
+                    past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+            if decoder_position_ids is not None:
+                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)
+
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
+            )
+        elif use_cache:
+            cache_position = cache_position[-decoder_input_ids.shape[1] :]
+
+        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+        decoder_input_ids = decoder_input_ids.contiguous()
+
+        return {
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "use_cache": use_cache,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
+            "cache_position": cache_position,
+        }
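
Below is a brief usage sketch of the stateful Whisper path this series enables; it is not part of the patches. The checkpoint id and the silent dummy audio are illustrative assumptions. The only behaviors it relies on from the series are that export=True now produces a stateful decoder by default (the stateful kwarg is forwarded to main_export) and that decoder_with_past is no longer created, as asserted in the updated calibration test.

    # Hedged usage sketch: model id and input audio are assumptions, not part of the patch.
    import numpy as np
    from transformers import AutoProcessor
    from optimum.intel import OVModelForSpeechSeq2Seq

    model_id = "openai/whisper-tiny"  # assumption: any Whisper checkpoint should behave the same
    processor = AutoProcessor.from_pretrained(model_id)

    # export=True converts the checkpoint to OpenVINO IR; stateful export is the default in this series
    model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
    assert model.decoder_with_past is None  # the stateful decoder keeps its KV cache inside the infer request

    # 30 s of silence at 16 kHz stands in for a real recording
    audio = np.zeros(16000 * 30, dtype=np.float32)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    generated_ids = model.generate(inputs.input_features, max_new_tokens=32)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))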