diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx index 65480c1d2f..e0b60baa2e 100644 --- a/docs/source/inference.mdx +++ b/docs/source/inference.mdx @@ -99,21 +99,22 @@ tokenizer.save_pretrained(save_directory) ### Weight-only quantization -You can also apply 8-bit or 4-bit weight quantization when exporting your model with the CLI by setting the `weight-format` argument to respectively `int8` or `int4`: +You can also apply fp16, 8-bit or 4-bit weight compression on the Linear, Convolutional and Embedding layers when exporting your model with the CLI by setting `--weight-format` to `fp16`, `int8` or `int4` respectively: ```bash optimum-cli export openvino --model gpt2 --weight-format int8 ov_model ``` -This will result in the exported model linear and embedding layers to be quantized to INT8 or INT4, the activations will be kept in floating point precision. This type of optimization allows reducing the footprint and latency of LLMs. +This type of optimization allows reducing the memory footprint and inference latency. -By default the quantization scheme will be [assymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization), to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) you can add `--sym`. + +By default the quantization scheme will be [asymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization), to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) you can add `--sym`. For INT4 quantization you can also specify the following arguments : * The `--group-size` parameter will define the group size to use for quantization, `-1` it will results in per-column quantization. * The `--ratio` parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to `int4` while 10% will be quantized to `int8`. -Smaller `group_size` and `ratio` of usually improve accuracy at the sacrifice of the model size and inference latency. +Smaller `group_size` and `ratio` values usually improve accuracy at the expense of model size and inference latency. You can also apply 8-bit quantization on your model's weight when loading your model by setting the `load_in_8bit=True` argument when calling the `from_pretrained()` method. @@ -125,7 +126,7 @@ model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) -`load_in_8bit` is enabled by default for the models larger than 1 billion parameters. +`load_in_8bit` is enabled by default for models larger than 1 billion parameters. You can disable it with `load_in_8bit=False`. diff --git a/docs/source/optimization_ov.mdx b/docs/source/optimization_ov.mdx index 088b78f0d3..1e78c36805 100644 --- a/docs/source/optimization_ov.mdx +++ b/docs/source/optimization_ov.mdx @@ -19,15 +19,72 @@ limitations under the License. 🤗 Optimum Intel provides an `openvino` package that enables you to apply a variety of model compression methods such as quantization, pruning, on many models hosted on the 🤗 hub using the [NNCF](https://docs.openvino.ai/2022.1/docs_nncf_introduction.html) framework.
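As noted above, `load_in_8bit` is only applied by default to models larger than 1 billion parameters. A minimal sketch of opting out of it explicitly when loading a model (the model ID is the one reused in the examples below):

```python
from optimum.intel import OVModelForCausalLM

model_id = "helenai/gpt2-ov"  # placeholder; any supported causal LM works
# Explicitly disable the default 8-bit weight compression that would otherwise
# be applied automatically to models above 1 billion parameters
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=False)
```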
-## Post-training optimization +## Post-training -Post-training static quantization introduces an additional calibration step where data is fed through the network in order to compute the activations quantization parameters. -Here is how to apply static quantization on a fine-tuned DistilBERT: +Quantization is a technique to reduce the computational and memory costs of running inference by representing the weights and/or the activations with lower precision data types like 8-bit or 4-bit. + +### Weight-only quantization + +Quantization can be applied on the model's Linear, Convolutional and Embedding layers, enabling the loading of large models on memory-limited devices. For example, when applying 8-bit quantization, the resulting model will be x4 smaller than its fp32 counterpart. For 4-bit quantization, the reduction in memory could theoretically reach x8, but is closer to x6 in practice. + + +#### 8-bit + +For 8-bit weight quantization, you can set `load_in_8bit=True` to load your model's weights in 8-bit: ```python -from functools import partial -from transformers import AutoTokenizer -from optimum.intel import OVConfig, OVQuantizer, OVModelForSequenceClassification, +from optimum.intel import OVModelForCausalLM + +model_id = "helenai/gpt2-ov" +model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) + +# Saves the int8 model that will be x4 smaller than its fp32 counterpart +model.save_pretrained(saving_directory) +``` + + + +`load_in_8bit` is enabled by default for models larger than 1 billion parameters. You can disable it with `load_in_8bit=False`. + + + +You can also provide a `quantization_config` instead to specify additional optimization parameters. + +#### 4-bit + +For 4-bit weight quantization, you need a `quantization_config` to define the optimization parameters, for example: + +```python +from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig + +quantization_config = OVWeightQuantizationConfig(bits=4) +model = OVModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) +``` + +You can tune quantization parameters to achieve a better performance-accuracy trade-off as follows: + +```python +quantization_config = OVWeightQuantizationConfig(bits=4, sym=False, ratio=0.8, dataset="ptb") +``` + +By default the quantization scheme will be [asymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#asymmetric-quantization), to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/Quantization.md#symmetric-quantization) you can add `sym=True`. + +For 4-bit quantization you can also specify the following arguments in the quantization configuration: +* The `group_size` parameter will define the group size to use for quantization; `-1` results in per-column quantization. +* The `ratio` parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to `int4` while 10% will be quantized to `int8`. + +Smaller `group_size` and `ratio` values usually improve accuracy at the expense of model size and inference latency. + +### Static quantization + +When applying post-training static quantization, both the weights and the activations are quantized.
+To apply quantization on the activations, an additional calibration step is needed which consists of feeding a `calibration_dataset` to the network in order to estimate the activation quantization parameters. + +Here is how to apply static quantization on a fine-tuned DistilBERT given your own `calibration_dataset`: + +```python +from transformers import AutoTokenizer +from optimum.intel import OVQuantizer, OVModelForSequenceClassification model_id = "distilbert-base-uncased-finetuned-sst-2-english" model = OVModelForSequenceClassification.from_pretrained(model_id, export=True) @@ -35,11 +92,22 @@ tokenizer = AutoTokenizer.from_pretrained(model_id) # The directory where the quantized model will be saved save_dir = "ptq_model" +quantizer = OVQuantizer.from_pretrained(model) + +# Apply static quantization and export the resulting quantized model to OpenVINO IR format +quantizer.quantize(calibration_dataset=calibration_dataset, save_directory=save_dir) +# Save the tokenizer +tokenizer.save_pretrained(save_dir) +``` + +The calibration dataset can also be created easily using your `OVQuantizer`: + +```python +from functools import partial + def preprocess_function(examples, tokenizer): return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True) -# Instantiate our OVQuantizer using the desired configuration -quantizer = OVQuantizer.from_pretrained(model) # Create the calibration dataset used to perform static quantization calibration_dataset = quantizer.get_calibration_dataset( "glue", @@ -48,59 +116,39 @@ calibration_dataset = quantizer.get_calibration_dataset( num_samples=300, dataset_split="train", ) -# Apply static quantization and export the resulting quantized model to OpenVINO IR format -quantizer.quantize( - calibration_dataset=calibration_dataset, - save_directory=save_dir, -) -# Save the tokenizer -tokenizer.save_pretrained(save_dir) ``` -The `quantize()` method applies post-training static quantization and export the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The resulting model can be run on any target Intel device. -## Weight-only quantization +The `quantize()` method applies post-training static quantization and exports the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The resulting model can be run on any target Intel device. -You can optimize the performance of text-generation LLMs by quantizing weights to various precisions that provide different performance-accuracy trade-offs. -```python -from optimum.intel import OVModelForCausalLM +### Hybrid quantization -model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True) -``` +Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion (SD) models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of the activations is comparable to that of the weights. +The U-Net component takes up most of the overall execution time of the pipeline. Thus, optimizing just this one component can bring substantial benefits in terms of inference speed while keeping acceptable accuracy without fine-tuning.
Quantizing the rest of the diffusion pipeline does not significantly improve inference performance but could potentially lead to substantial accuracy degradation. +Therefore, the proposal is to apply quantization in *hybrid mode* for the U-Net model and weight-only quantization for the rest of the pipeline components : +* U-Net : quantization applied on both the weights and activations +* The text encoder, VAE encoder / decoder : quantization applied on the weights - - -`load_in_8bit` is enabled by default for the models larger than 1 billion parameters. - - +The hybrid mode involves the quantization of weights in MatMul and Embedding layers, and activations of other layers, facilitating accuracy preservation post-optimization while reducing the model size. -For the 4-bit weight quantization you can use the `quantization_config` to specify the optimization parameters, for example: +The `quantization_config` is utilized to define optimization parameters for optimizing the SD pipeline. To enable hybrid quantization, specify the quantization dataset in the `quantization_config`. If the dataset is not defined, weight-only quantization will be applied on all components. ```python -from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig +from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig -model = OVModelForCausalLM.from_pretrained( +model = OVStableDiffusionPipeline.from_pretrained( model_id, - quantization_config=OVWeightQuantizationConfig(bits=4), + export=True, + quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions"), ) ``` -You can tune quantization parameters to achieve a better performance accuracy trade-off as follows: - -```python -from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig - -model = OVModelForCausalLM.from_pretrained( - model_id, - quantization_config=OVWeightQuantizationConfig(bits=4, sym=False, ratio=0.8, dataset="ptb"), -) -``` For more details, please refer to the corresponding NNCF [documentation](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/CompressWeights.md). -## Training-time optimization +## Training-time Apart from optimizing a model after training like post-training quantization above, `optimum.openvino` also provides optimization methods during training, namely Quantization-Aware Training (QAT) and Joint Pruning, Quantization and Distillation (JPQD). diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py index 9664f6ae6d..94ea4f103b 100644 --- a/optimum/exporters/openvino/__init__.py +++ b/optimum/exporters/openvino/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import optimum.exporters.openvino.model_configs + from .__main__ import main_export from .convert import export, export_from_model, export_models, export_pytorch_via_onnx from .stateful import ensure_stateful_is_available, patch_stateful diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 1c695e2f19..02268a3604 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -58,7 +58,7 @@ def main_export( local_files_only: bool = False, use_auth_token: Optional[Union[bool, str]] = None, model_kwargs: Optional[Dict[str, Any]] = None, - custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, @@ -112,11 +112,11 @@ def main_export( when running `transformers-cli login` (stored in `~/.huggingface`). model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during - the export. This argument should be used along the `custom_onnx_configs` argument + the export. This argument should be used along the `custom_export_configs` argument in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). - custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): - Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). + custom_export_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): + Experimental usage: override the default export config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). fn_get_submodels (`Optional[Callable]`, defaults to `None`): Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. 
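Since `custom_onnx_configs` is renamed to `custom_export_configs` throughout this file, here is a hedged sketch of how the renamed argument could be passed; reusing `GPT2OnnxConfig` from `optimum.exporters.onnx` and the `"model"` submodel key are assumptions made for illustration, not something introduced by this patch:

```python
from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import GPT2OnnxConfig
from optimum.exporters.openvino import main_export

# Build an export config for the single "model" submodel and pass it under the new name
config = AutoConfig.from_pretrained("gpt2")
export_config = GPT2OnnxConfig(config, task="text-generation")
main_export(
    "gpt2",
    output="gpt2_ov/",
    task="text-generation",
    custom_export_configs={"model": export_config},  # previously `custom_onnx_configs`
)
```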
@@ -134,7 +134,7 @@ def main_export( ```python >>> from optimum.exporters.openvino import main_export - >>> main_export("gpt2", output="gpt2_onnx/") + >>> main_export("gpt2", output="gpt2_ov/") ``` """ @@ -206,14 +206,14 @@ def main_export( if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: custom_architecture = True elif task not in TasksManager.get_supported_tasks_for_model_type( - model_type, exporter="onnx", library_name=library_name + model_type, exporter="openvino", library_name=library_name ): if original_task == "auto": autodetected_message = " (auto-detected)" else: autodetected_message = "" model_tasks = TasksManager.get_supported_tasks_for_model_type( - model_type, exporter="onnx", library_name=library_name + model_type, exporter="openvino", library_name=library_name ) raise ValueError( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." @@ -288,7 +288,7 @@ class StoreAttr(object): not custom_architecture and library_name != "diffusers" and task + "-with-past" - in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx", library_name=library_name) + in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name) ): # Make -with-past the default if --task was not explicitely specified if original_task == "auto": @@ -319,7 +319,7 @@ class StoreAttr(object): ov_config=ov_config, stateful=stateful, model_kwargs=model_kwargs, - custom_onnx_configs=custom_onnx_configs, + custom_export_configs=custom_export_configs, fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, device=device, diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 5353912d48..dfca80f001 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -32,10 +32,11 @@ from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx +from optimum.exporters.utils import _get_submodels_and_export_configs from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available from optimum.utils.save_utils import maybe_save_preprocessors -from ...intel.utils.import_utils import is_nncf_available, is_optimum_version +from ...intel.utils.import_utils import is_nncf_available from .model_patcher import patch_model_with_bettertransformer from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful from .utils import ( @@ -48,13 +49,6 @@ ) -if is_optimum_version(">=", "1.16.99"): - from optimum.exporters.onnx.utils import _get_submodels_and_onnx_configs - -else: - from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs - - UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast) @@ -418,7 +412,7 @@ def ts_patched_forward(*args, **kwargs): def export_models( - models_and_onnx_configs: Dict[ + models_and_export_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] ], output_dir: Path, @@ -434,7 +428,7 @@ def export_models( Export the models to OpenVINO IR format 
Args: - models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]): + models_and_export_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]): output_dir (Path): output directory for saving models opset (Optional[int], optional, Default to None): ONNX export opset output_names (Optional[List[str]], optional, Defaults to None): model output names @@ -459,20 +453,20 @@ def export_models( outputs = [] - if output_names is not None and len(output_names) != len(models_and_onnx_configs): + if output_names is not None and len(output_names) != len(models_and_export_configs): raise ValueError( - f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export." + f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export." ) - for i, model_name in enumerate(models_and_onnx_configs.keys()): - submodel, sub_onnx_config = models_and_onnx_configs[model_name] + for i, model_name in enumerate(models_and_export_configs.keys()): + submodel, sub_export_config = models_and_export_configs[model_name] output_name = output_names[i] if output_names is not None else Path(model_name + ".xml") output_path = output_dir / output_name output_path.parent.mkdir(parents=True, exist_ok=True) outputs.append( export( model=submodel, - config=sub_onnx_config, + config=sub_export_config, output=output_path, opset=opset, device=device, @@ -495,7 +489,7 @@ def export_from_model( stateful: bool = True, opset: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = None, - custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, preprocessors: List = None, device: str = "cpu", @@ -524,14 +518,14 @@ def export_from_model( task = TasksManager._infer_task_from_model_or_model_class(model=model) except (ValueError, KeyError) as e: raise RuntimeError( - f"The model task could not be automatically inferred in `onnx_export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + f"The model task could not be automatically inferred in `export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) if ( not custom_architecture and library_name != "diffusers" and task + "-with-past" - in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx", library_name=library_name) + in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino", library_name=library_name) ): # -with-past is the default. task = task + "-with-past" @@ -541,9 +535,9 @@ def export_from_model( stateful = stateful and ensure_export_task_support_stateful(task) # TODO: support onnx_config.py in the model repo - if custom_architecture and custom_onnx_configs is None: + if custom_architecture and custom_export_configs is None: raise ValueError( - f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. 
Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." + f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." ) if task.startswith("text-generation") and model.config.is_encoder_decoder: @@ -569,11 +563,11 @@ def export_from_model( kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] ) - onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( + export_config, models_and_export_configs = _get_submodels_and_export_configs( model=model, task=task, monolith=False, - custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, + custom_export_configs=custom_export_configs if custom_export_configs is not None else {}, custom_architecture=custom_architecture, fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, @@ -581,6 +575,7 @@ def export_from_model( model_kwargs=model_kwargs, _variant="default", legacy=False, + exporter="openvino", ) if ov_config is None: @@ -612,18 +607,18 @@ def export_from_model( model_name_or_path = model.config._name_or_path maybe_save_preprocessors(model_name_or_path, output) - files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()] + files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()] else: # save the subcomponent configuration - for model_name in models_and_onnx_configs: - subcomponent = models_and_onnx_configs[model_name][0] + for model_name in models_and_export_configs: + subcomponent = models_and_export_configs[model_name][0] if hasattr(subcomponent, "save_config"): subcomponent.save_config(output / model_name) elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): subcomponent.config.save_pretrained(output / model_name) - files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs] + files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_export_configs] # Saving the additional components needed to perform inference. model.scheduler.save_pretrained(output.joinpath("scheduler")) @@ -643,7 +638,7 @@ def export_from_model( model.save_config(output) export_models( - models_and_onnx_configs=models_and_onnx_configs, + models_and_export_configs=models_and_export_configs, output_dir=output, output_names=files_subpaths, input_shapes=input_shapes, diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py new file mode 100644 index 0000000000..b6536512b1 --- /dev/null +++ b/optimum/exporters/openvino/model_configs.py @@ -0,0 +1,391 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +from packaging import version +from transformers.utils import is_tf_available + +from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig +from optimum.exporters.onnx.model_configs import GemmaOnnxConfig +from optimum.exporters.tasks import TasksManager +from optimum.utils import DEFAULT_DUMMY_SHAPES +from optimum.utils.input_generators import ( + DummyInputGenerator, + DummyPastKeyValuesGenerator, + DummyTextInputGenerator, + MistralDummyPastKeyValuesGenerator, +) +from optimum.utils.normalized_config import NormalizedTextConfig + +from .model_patcher import ( + BaichuanModelPatcher, + ChatGLMModelPatcher, + GemmaModelPatcher, + MixtralModelPatcher, + QwenModelPatcher, +) + + +def init_model_configs(): + supported_model_types = [ + "_SUPPORTED_MODEL_TYPE", + "_DIFFUSERS_SUPPORTED_MODEL_TYPE", + "_TIMM_SUPPORTED_MODEL_TYPE", + "_SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE", + ] + + for supported_models_config in supported_model_types: + supported_models = getattr(TasksManager, supported_models_config) + for model, export_configs in supported_models.items(): + if "onnx" not in export_configs: + continue + onnx_config = export_configs["onnx"] + supported_models[model]["openvino"] = deepcopy(onnx_config) + + setattr(TasksManager, supported_models_config, supported_models) + + +init_model_configs() + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + from optimum.exporters.onnx.model_patcher import ModelPatcher + + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + + +register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True) + + +@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers") +class BaichaunOpenVINOConfig(TextDecoderOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size" + ) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BaichuanModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers") +class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class 
MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.multi_query_group_num = normalized_config.multi_query_group_num + self.head_dim = normalized_config.kv_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + past_key_shape = ( + self.sequence_length, + self.batch_size, + self.multi_query_group_num, + self.head_dim, + ) + past_value_shape = ( + self.sequence_length, + self.batch_size, + self.multi_query_group_num, + self.head_dim, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers") + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] + if self.use_past_in_inputs and self.use_cache_branch is not False: + input_names.append("past_key_values") + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' 
+ ) + + # refer to https://github.com/huggingface/optimum/pull/764 + if ( + self.use_past_in_inputs + and self.PAD_ATTENTION_MASK_TO_PAST + and self.use_cache_branch is not False + and "attention_mask" in dummy_inputs + ): + # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2 + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0] + + dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( + dummy_inputs["attention_mask"], + desired_length=past_present_length, + dim=1, + dtype=dummy_inputs["attention_mask"].dtype, + ) + + return dummy_inputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. + + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + present_lenght" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name} + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers") +class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + + # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 + DEFAULT_ONNX_OPSET = 14 + DUMMY_INPUT_GENERATOR_CLASSES = ( + MistralDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MixtralModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "gemma", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class GemmaOpenVINOConfig(GemmaOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return GemmaModelPatcher(self, model, model_kwargs=model_kwargs) + + +class QwenDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def 
__init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.kv_channels = normalized_config.kv_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + past_key_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels) + past_value_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("qwen", *["text-generation", "text-generation-with-past"]) +class QwenOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size" + ) + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, QwenDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = QwenDummyPastKeyValuesGenerator + no_position_ids = False + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] + if self.use_past_in_inputs and self.use_cache_branch is not False: + input_names.append("past_key_values") + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + # refer to https://github.com/huggingface/optimum/pull/764 + if ( + self.use_past_in_inputs + and self.PAD_ATTENTION_MASK_TO_PAST + and self.use_cache_branch is not False + and "attention_mask" in dummy_inputs + ): + # Obtain the past sequence length from the value instead of the key (Bloom). Qwen has seq_len in 1 dim instead of -2 + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[1] + + dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( + dummy_inputs["attention_mask"], + desired_length=past_present_length, + dim=1, + dtype=dummy_inputs["attention_mask"].dtype, + ) + + return dummy_inputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. 
+ + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 1: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 1: decoder_sequence_name} + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return QwenModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index aea57161e2..371fee732a 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,23 +13,43 @@ # limitations under the License. import logging as log +import types +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import is_tf_available + +from optimum.exporters.onnx.model_patcher import DecoderModelPatcher from optimum.intel.utils.import_utils import ( + _openvino_version, _torch_version, _transformers_version, + is_openvino_version, is_torch_version, is_transformers_version, ) +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + from optimum.exporters.onnx.config import OnnxConfig + + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + + def patch_model_with_bettertransformer(model): + COLOR_RED = "\033[1;31m" + COLOR_RESET = "\033[0m" + # check that the model has not yet been pathced if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: return model if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"): - COLOR_RED = "\033[1;31m" - COLOR_RESET = "\033[0m" log.warn( COLOR_RED + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. " @@ -39,6 +59,22 @@ def patch_model_with_bettertransformer(model): + COLOR_RESET ) + if ( + getattr(model.config, "model_type") in {"gpt_bigcode", "llama"} + and is_transformers_version(">=", "4.38") + and is_openvino_version("<", "2024.1.0-14612") + ): + # display commit-id only when a nightly/prerelease of OpenVINO is installed. + display_version = ( + _openvino_version.split("-")[0] if is_openvino_version("<=", "2024.0.0-14509") else _openvino_version + ) + log.warn( + COLOR_RED + f"[WARNING] Stateful models are not supported for Llama and GPTBigCode with Transformers " + f"{_transformers_version} and OpenVINO {display_version}. 
For good performance, consider using a nightly OpenVINO build: " + "https://docs.openvino.ai/2024/get-started/install-openvino.html. For models that do not need transformers " + "4.38+, it is also an option to downgrade transformers: `pip install transformers==4.37.2`" + COLOR_RESET + ) + # model already has required SDPA implementation if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": return model @@ -52,3 +88,425 @@ def patch_model_with_bettertransformer(model): return model return model + + +def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward + layer.block_sparse_moe.forward = types.MethodType( + _mixtral_sparse_moe_block_forward, layer.block_sparse_moe + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward + + +def _chatglm_transformer_forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if past_key_values is not None: + full_attention_mask = torch.ones( + batch_size, + seq_length, + seq_length, + device=input_ids.device, + dtype=torch.float, + ) * float("-inf") + full_attention_mask.triu_(diagonal=1) + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.zeros(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + full_attention_mask.unsqueeze_(1) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. 
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor): + mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype) + if query_layer.shape[2] == key_layer.shape[2]: + tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1) + mask.masked_fill_(tmp_mask, float("-inf")) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask=mask + ) + return context_layer + + +def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask): + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None: + context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer) + else: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + + return context_layer + + +class ChatGLMModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, model_kwargs) + + self.original_chatglm_transformer_forward = model.transformer.forward + + def __enter__(self): + super().__enter__() + self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer) + for block in self._model.transformer.encoder.layers: + block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward + block.self_attention.core_attention.forward = types.MethodType( + _chatglm2_core_attention_forward, block.self_attention.core_attention + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.transformer.forward = self.original_chatglm_transformer_forward + for block in self._model.transformer.encoder.layers: + block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward + + +class GemmaModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + # init inv_freq for torchscript tracing + # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108 + for layer in self._model.model.layers: + if layer.self_attn.rotary_emb.inv_freq is None: + rotary_emb = layer.self_attn.rotary_emb + layer.self_attn.rotary_emb.inv_freq = 1.0 / ( + rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) + ) + + +SUPPORT_SDPA = is_torch_version(">", "2.1.0") + + +def 
_qwen_rotate_half(x): + from einops import rearrange + + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _qwen_apply_rotary_pos_emb(t, freqs): + cos, sin = freqs + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] + t_ = t_.float() + t_pass_ = t_pass_.float() + t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin) + return torch.cat((t_, t_pass_), dim=-1).type_as(t) + + +def _qwen_quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() + zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + + +def _qwen_attention_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, +): + mixed_x_layer = self.c_attn(hidden_states) + + query, key, value = mixed_x_layer.split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = _qwen_apply_rotary_pos_emb(query, q_pos_emb) + key = _qwen_apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)] + key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) + value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = ( + 
torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2), + ) + value = ( + torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2), + ) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache: + present = (key, value) + else: + present = None + + if self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) + query = query * logn_tensor.expand_as(query) + + if self.use_flash_attn and not self.is_fp32 and query.is_cuda: + q, k, v = query, key, value + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) + else: + registered_causal_mask = torch.tril( + torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device) + ).view(1, 1, key.size(1), key.size(1)) + query = query.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + if not self.use_cache_quantization and SUPPORT_SDPA: + causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)] + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill( + ~causal_mask, torch.finfo(query.dtype).min + ) + else: + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask) + context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + attn_output = self.c_proj(context_layer) + + outputs = (attn_output, present) + if output_attentions: + if self.use_flash_attn and not self.is_fp32: + raise ValueError("Cannot output attentions while using flash-attn") + else: + outputs += (attn_weight,) + + return outputs + + +class QwenModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, model_kwargs) + + self.original_fp16 = model.config.fp16 + self.original_bf16 = model.config.bf16 + model.config.bf16 = False + model.config.fp16 = False + if self.original_fp16 or self.original_bf16: + model.to(torch.float32) + model.transformer.rotary_emb(2048) + + def __enter__(self): + super().__enter__() + for block in self._model.transformer.h: + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.transformer.h: + block.attn.forward = block.attn._orig_forward + self._model.config.bf16 = self.original_bf16 + self._model.config.fp16 = self.original_fp16 + + +class BaichuanModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, 
model_kwargs) + # model has first inference buffers initialization + if self._model.lm_head.first_flag: + self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 8ddd005279..40a60bb58e 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -167,7 +167,7 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin): bits (`int`, defaults to 8): The number of bits to quantize to. - sym (`bool`, *optional*, defaults to `False`): + sym (`bool`, defaults to `False`): Whether to use symetric quantization. tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*): The tokenizer used to process the dataset. You can pass either: @@ -177,23 +177,24 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin): user or organization name, like `dbmdz/bert-base-german-cased`. - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. - dataset (`Union[List[str]]`, *optional*): - The dataset used for data-aware compression. You can provide your own dataset in a list of string or just use the - the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] - group_size (`int`, *optional*, defaults to 128): - The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. - ratio (`float`, *optional*, defaults to 1.0): + dataset (`str or List[str]`, *optional*): + The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset + in a list of strings or just use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs + or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models. + ratio (`float`, defaults to 1.0): The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM and the rest to INT8_ASYM). + group_size (`int`, *optional*): + The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. all_layers (`bool`, *optional*): Defines how many layers are compressed to 4-bits while the rest are kept in 8-bit presicion. - sensitivity_metric (`nncf.SensitivityMetric`, *optional*): + sensitivity_metric (`str`, *optional*): The sensitivity metric for assigning quantization precision to layers. In order to preserve the accuracy of the model, the more sensitive layers receives a higher precision. - awq (`bool`, *optional*): - Enables AWQ method to unify weight ranges and improve overall model accuracy. - ignored_scope (`nncf.IgnoredScope`, *optional*): + ignored_scope (`dict`, *optional*): An ignored scope that defined the list of model control flow graph nodes to be ignored during quantization. + num_samples (`int`, *optional*): + The maximum number of samples composing the calibration dataset.
""" @@ -202,12 +203,13 @@ def __init__( bits: int = 8, sym: bool = False, tokenizer: Optional[Any] = None, - dataset: Optional[str] = None, + dataset: Optional[Union[str, List[str]]] = None, ratio: float = 1.0, group_size: Optional[int] = None, all_layers: Optional[bool] = None, sensitivity_metric: Optional[str] = None, ignored_scope: Optional[dict] = None, + num_samples: Optional[int] = None, **kwargs, ): self.bits = bits @@ -219,6 +221,7 @@ def __init__( self.all_layers = all_layers self.sensitivity_metric = sensitivity_metric self.ignored_scope = ignored_scope + self.num_samples = num_samples self.quant_method = "default" # TODO : enable AWQ after nncf v2.9.0 release self.post_init() @@ -231,10 +234,16 @@ def post_init(self): if self.group_size is not None and self.group_size != -1 and self.group_size <= 0: raise ValueError("`group_size` must be greater than 0 or equal to -1") if self.dataset is not None and isinstance(self.dataset, str): - if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]: + llm_datasets = ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"] + stable_diffusion_datasets = [ + "conceptual_captions", + "laion/220k-GPT4Vision-captions-from-LIVIS", + "laion/filtered-wit", + ] + if self.dataset not in llm_datasets + stable_diffusion_datasets: raise ValueError( f"""You have entered a string value for dataset. You can only choose between - ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}""" + {llm_datasets} for LLLMs or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}""" ) if self.bits not in [4, 8]: diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 7ab99aab42..15f1fc4f1c 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -388,7 +388,7 @@ def compile(self): if ( "CACHE_DIR" not in self.ov_config.keys() and not str(self.model_save_dir).startswith(gettempdir()) - and self._device.lower() == "gpu" + and "gpu" in self._device.lower() ): # Set default CACHE_DIR only if it is not set, if the model is not in a temporary directory, and device is GPU cache_dir = Path(self.model_save_dir).joinpath("model_cache") diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index edc88d02cb..832c132615 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -316,7 +316,9 @@ def _reshape( shapes[inputs][0] = -1 input_name = inputs.get_any_name() if input_name.startswith("past_key_values"): - if len(inputs.partial_shape) == 3 and input_name.endswith("value"): + if ( + len(inputs.partial_shape) == 3 and input_name.endswith("value") + ) or self.config.model_type == "chatglm": shapes[inputs][1] = -1 else: shapes[inputs][2] = -1 @@ -635,7 +637,8 @@ def _from_pretrained( # from optimum.gptq.utils import get_seqlen # seqlen = get_seqlen(causal_model) - dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32) + nsamples = quantization_config.num_samples if quantization_config.num_samples else 128 + dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples) dataset = prepare_dataset(dataset) quantization_config = copy.deepcopy(quantization_config) quantization_config.dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x)) diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index a985f43d7c..f0fea5a8ce 100644 --- 
a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -16,6 +16,7 @@ import logging import os import shutil +from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory, gettempdir from typing import Any, Dict, List, Optional, Union @@ -57,7 +58,13 @@ from .configuration import OVConfig, OVWeightQuantizationConfig from .loaders import OVTextualInversionLoaderMixin from .modeling_base import OVBaseModel -from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, _print_compiled_model_properties +from .utils import ( + ONNX_WEIGHTS_NAME, + OV_TO_NP_TYPE, + OV_XML_FILE_NAME, + PREDEFINED_SD_DATASETS, + _print_compiled_model_properties, +) core = Core() @@ -274,9 +281,19 @@ def _from_pretrained( kwargs[name] = load_method(new_model_save_dir) quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit) - unet = cls.load_model( - new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, quantization_config - ) + + unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name + if quantization_config is not None and quantization_config.dataset is not None: + # load the UNet model uncompressed to apply hybrid quantization further + unet = cls.load_model(unet_path) + # Apply weights compression to other `components` without dataset + weight_quantization_params = { + param: value for param, value in quantization_config.__dict__.items() if param != "dataset" + } + weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params) + else: + weight_quantization_config = quantization_config + unet = cls.load_model(unet_path, weight_quantization_config) components = { "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, @@ -286,11 +303,29 @@ def _from_pretrained( } for key, value in components.items(): - components[key] = cls.load_model(value, quantization_config) if value.is_file() else None + components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None if model_save_dir is None: model_save_dir = new_model_save_dir + if quantization_config is not None and quantization_config.dataset is not None: + sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs) + + supported_pipelines = ( + OVStableDiffusionPipeline, + OVStableDiffusionXLPipeline, + OVLatentConsistencyModelPipeline, + ) + if not isinstance(sd_model, supported_pipelines): + raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}") + + nsamples = quantization_config.num_samples if quantization_config.num_samples else 200 + unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples) + + from .quantization import _hybrid_quantization + + unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs) + return cls( unet=unet, config=config, @@ -300,6 +335,62 @@ def _from_pretrained( **kwargs, ) + def _prepare_unet_inputs( + self, + dataset: Union[str, List[Any]], + num_samples: int, + height: Optional[int] = None, + width: Optional[int] = None, + seed: Optional[int] = 42, + **kwargs, + ) -> Dict[str, Any]: + self.compile() + + size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor + height = height or min(size, 512) + width = width or min(size, 512) + + if isinstance(dataset, str): + dataset = deepcopy(dataset) + available_datasets = 
PREDEFINED_SD_DATASETS.keys() + if dataset not in available_datasets: + raise ValueError( + f"""You have entered a string value for dataset. You can only choose between + {list(available_datasets)}, but the {dataset} was found""" + ) + + from datasets import load_dataset + + dataset_metadata = PREDEFINED_SD_DATASETS[dataset] + dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed) + input_names = dataset_metadata["inputs"] + dataset = dataset.select_columns(list(input_names.values())) + + def transform_fn(data_item): + return {inp_name: data_item[column] for inp_name, column in input_names.items()} + + else: + + def transform_fn(data_item): + return data_item if isinstance(data_item, (list, dict)) else [data_item] + + from .quantization import InferRequestWrapper + + calibration_data = [] + self.unet.request = InferRequestWrapper(self.unet.request, calibration_data) + + for inputs in dataset: + inputs = transform_fn(inputs) + if isinstance(inputs, dict): + self.__call__(**inputs, height=height, width=width) + else: + self.__call__(*inputs, height=height, width=width) + if len(calibration_data) > num_samples: + break + + self.unet.request = self.unet.request.request + return calibration_data[:num_samples] + @classmethod def _from_transformers( cls, @@ -578,7 +669,7 @@ def _compile(self): if ( "CACHE_DIR" not in self.ov_config.keys() and not str(self._model_dir).startswith(gettempdir()) - and self.device.lower() == "gpu" + and self.device.lower().split(":")[0] == "gpu" ): self.ov_config["CACHE_DIR"] = os.path.join(self._model_dir, self._model_name, "model_cache") diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 617d898be5..d68cbc75ed 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -451,7 +451,7 @@ def _compile(self): if ( "CACHE_DIR" not in ov_config.keys() and not str(self.parent_model.model_save_dir).startswith(gettempdir()) - and self._device.lower() == "gpu" + and "gpu" in self._device.lower() ): cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache") ov_config["CACHE_DIR"] = str(cache_dir) @@ -563,7 +563,7 @@ def _compile(self): if ( "CACHE_DIR" not in ov_config.keys() and not str(self.parent_model.model_save_dir).startswith(gettempdir()) - and self._device.lower() == "gpu" + and "gpu" in self._device.lower() ): cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache") ov_config["CACHE_DIR"] = str(cache_dir) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index cd26f91f22..2022a495d8 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -16,6 +16,7 @@ import inspect import logging import os +from collections import deque from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -24,6 +25,7 @@ import torch import transformers from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric +from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters from nncf.torch import create_compressed_model, register_default_init_args, register_module from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk from nncf.torch.initialization import PTInitializingDataLoader @@ -348,7 +350,7 @@ def _quantize_torchmodel( model_type = self.model.config.model_type.replace("_", "-") onnx_config_class = 
TasksManager.get_exporter_config_constructor( - exporter="onnx", + exporter="openvino", model=self.model, task=self.task, model_type=model_type, @@ -550,7 +552,7 @@ def _remove_unused_columns(self, dataset: "Dataset"): def _weight_only_quantization( model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict] -): +) -> openvino.runtime.Model: config = quantization_config if isinstance(config, dict): config = OVWeightQuantizationConfig.from_dict(quantization_config) @@ -564,7 +566,8 @@ def _weight_only_quantization( from optimum.gptq.data import get_dataset, prepare_dataset - dataset = get_dataset(config.dataset, tokenizer, seqlen=32) + nsamples = config.num_samples if config.num_samples else 128 + dataset = get_dataset(config.dataset, tokenizer, seqlen=32, nsamples=nsamples) dataset = prepare_dataset(dataset) sensitivity_metric = None @@ -590,4 +593,92 @@ def _weight_only_quantization( # awq=config.quant_method == "awq", # TODO : remove and add it back once nncf v2.9.0 ignored_scope=ignored_scope, dataset=dataset, + # subset_size=config.num_samples if config.num_samples else 128, # TODO : enable from nncf v2.9.0 ) + + +def _get_operation_const_op(operation, const_port_id: int): + node = operation.input_value(const_port_id).get_node() + queue = deque([node]) + constant_node = None + allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"] + + while len(queue) != 0: + curr_node = queue.popleft() + if curr_node.get_type_name() == "Constant": + constant_node = curr_node + break + if len(curr_node.inputs()) == 0: + break + if curr_node.get_type_name() in allowed_propagation_types_list: + queue.append(curr_node.input_value(0).get_node()) + + return constant_node + + +def _is_embedding(node) -> bool: + allowed_types_list = ["f16", "f32", "f64"] + const_port_id = 0 + input_tensor = node.input_value(const_port_id) + if input_tensor.get_element_type().get_type_name() in allowed_types_list: + const_node = _get_operation_const_op(node, const_port_id) + if const_node is not None: + return True + + return False + + +def _collect_ops_with_weights(model): + ops_with_weights = [] + for op in model.get_ops(): + if op.get_type_name() == "MatMul": + constant_node_0 = _get_operation_const_op(op, const_port_id=0) + constant_node_1 = _get_operation_const_op(op, const_port_id=1) + if constant_node_0 or constant_node_1: + ops_with_weights.append(op.get_friendly_name()) + if op.get_type_name() == "Gather" and _is_embedding(op): + ops_with_weights.append(op.get_friendly_name()) + + return ops_with_weights + + +def _hybrid_quantization( + model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: Dict[str, Any] +) -> openvino.runtime.Model: + """ + Quantize a model in hybrid mode with NNCF, which means that we quantize: + weights of MatMul and Embedding layers and activations of other layers. + The optimization specifications are defined in `quantization_config`. + + Args: + model (`openvino.runtime.Model`): + The OpenVINO Runtime model for applying hybrid quantization. + quantization_config (`OVWeightQuantizationConfig`): + The configuration containing the parameters related to quantization. + dataset (`Dict[str, Any]`): + The dataset used for hybrid quantization. + Returns: + The OpenVINO Runtime model with applied hybrid quantization.
+ """ + ops_to_compress = _collect_ops_with_weights(model) + + ignored_scope = quantization_config.ignored_scope if isinstance(quantization_config.ignored_scope, dict) else {} + ptq_ignored_scope = nncf.IgnoredScope(**ignored_scope) + ptq_ignored_scope.names += ops_to_compress + + wc_quantization_config = copy.deepcopy(quantization_config) + wc_quantization_config.ignored_scope = ignored_scope + wc_quantization_config.ignored_scope["types"] = ignored_scope.get("types", []) + ["Convolution"] + compressed_model = _weight_only_quantization(model, wc_quantization_config) + + subset_size = quantization_config.num_samples if quantization_config.num_samples else 200 + quantized_model = nncf.quantize( + model=compressed_model, + calibration_dataset=nncf.Dataset(dataset), + model_type=nncf.ModelType.TRANSFORMER, + ignored_scope=ptq_ignored_scope, + # The SQ algo should be disabled for MatMul nodes because their weights are already compressed + advanced_parameters=nncf.AdvancedQuantizationParameters(AdvancedSmoothQuantParameters(matmul=-1)), + subset_size=subset_size, + ) + return quantized_model diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index 49aec81e57..a0439d2129 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -20,7 +20,7 @@ import numpy as np from huggingface_hub import model_info -from openvino.runtime import Type, properties +from openvino.runtime import Core, Type, properties from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size @@ -99,6 +99,13 @@ } +PREDEFINED_SD_DATASETS = { + "conceptual_captions": {"split": "train", "inputs": {"prompt": "caption"}}, + "laion/220k-GPT4Vision-captions-from-LIVIS": {"split": "train", "inputs": {"prompt": "caption"}}, + "laion/filtered-wit": {"split": "train", "inputs": {"prompt": "caption"}}, +} + + def use_external_data_format(num_parameters: int) -> bool: """ Returns whether or not the model requires using external data format for the ONNX export @@ -148,3 +155,9 @@ def _print_compiled_model_properties(compiled_model): logger.info(f" {k}: {value}") except Exception: logger.error(f"[error] Get property of '{k}' failed") + try: + logger.info("EXECUTION_DEVICES:") + for device in compiled_model.get_property("EXECUTION_DEVICES"): + logger.info(f" {device}: {Core().get_property(device, 'FULL_DEVICE_NAME')}") + except Exception: + logger.error("[error] Get FULL_DEVICE_NAME failed") diff --git a/setup.py b/setup.py index ac4056c30d..5c6cf76404 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,6 @@ +import os import re +import subprocess from setuptools import find_namespace_packages, setup @@ -8,13 +10,26 @@ filepath = "optimum/intel/version.py" with open(filepath) as version_file: (__version__,) = re.findall('__version__ = "(.*)"', version_file.read()) + if __version__.endswith(".dev0"): + dev_version_id = "" + try: + repo_root = os.path.dirname(os.path.realpath(__file__)) + dev_version_id = ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=repo_root) # nosec + .strip() + .decode() + ) + dev_version_id = "+" + dev_version_id + except subprocess.CalledProcessError: + pass + __version__ = __version__ + dev_version_id except Exception as error: assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) INSTALL_REQUIRE = [ "torch>=1.11", - "optimum~=1.17", "transformers>=4.36.0,<4.39.0", + "optimum @ git+https://github.com/huggingface/optimum.git#egg=optimum", "datasets>=1.4.0", "sentencepiece", "scipy", @@ -35,6 +50,8 @@ 
"timm", "invisible-watermark>=0.2.0", "auto-gptq", + "transformers_stream_generator", + "einops", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 2188b7061f..9df6c73214 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -28,6 +28,7 @@ from parameterized import parameterized from PIL import Image from transformers import ( + AutoConfig, AutoFeatureExtractor, AutoModel, AutoModelForAudioClassification, @@ -52,7 +53,6 @@ from transformers.onnx.utils import get_preprocessor from utils_tests import MODEL_NAMES -from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel import ( OVModelForAudioClassification, OVModelForAudioFrameClassification, @@ -473,73 +473,101 @@ def test_pipeline(self, model_arch): class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", + "baichuan2", "gpt_bigcode", "blenderbot", "blenderbot-small", "bloom", + "chatglm", "codegen", # "data2vec-text", # TODO : enable when enabled in exporters + "gemma", "gpt2", "gpt_neo", "gpt_neox", "llama", # "llama_gptq", "marian", + "minicpm", "mistral", + "mixtral", "mpt", "opt", "pegasus", + "qwen", + "qwen2", + "stablelm", ) GENERATION_LENGTH = 100 IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") + REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen") @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] + not_stateful = ["gpt_bigcode"] + if is_openvino_version("<", "2024.0"): + not_stateful.append("mixtral") + + if is_openvino_version("<", "2024.1"): + not_stateful.extend(["llama", "gemma"]) if "gptq" in model_arch: self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM") set_seed(SEED) - ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) + + model_kwargs = {} + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs) self.assertIsInstance(ov_model.config, PretrainedConfig) self.assertTrue(ov_model.use_cache) - - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + self.assertEqual( + ov_model.stateful, self.IS_SUPPORT_STATEFUL and ov_model.config.model_type not in not_stateful + ) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + if model_arch == "qwen": + transformers_model.to(torch.float32) tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None ) - position_ids = None - if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: - input_shape = tokens["input_ids"].shape - position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) - ov_outputs = ov_model(**tokens, position_ids=position_ids) + ov_outputs = ov_model(**tokens) self.assertTrue("logits" in ov_outputs) self.assertIsInstance(ov_outputs.logits, torch.Tensor) self.assertTrue("past_key_values" in ov_outputs) self.assertIsInstance(ov_outputs.past_key_values, tuple) - 
- is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL + is_stateful = ov_model.config.model_type not in not_stateful and self.IS_SUPPORT_STATEFUL self.assertEqual(ov_model.stateful, is_stateful) if is_stateful: self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) - with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Compare tensor outputs - self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4)) del transformers_model del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): + model_kwargs = {} model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False, compile=False) + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + model = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=False, compile=False, **model_kwargs + ) model.config.encoder_no_repeat_ngram_size = 0 model.to("cpu") model.half() @@ -556,8 +584,16 @@ def test_pipeline(self, model_arch): def test_multiple_inputs(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False) - tokenizer = AutoTokenizer.from_pretrained(model_id) + if model_arch == "qwen": + self.skipTest("Qwen tokenizer does not support padding") + model_kwargs = {} + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) tokenizer.pad_token = tokenizer.eos_token texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"] tokens = tokenizer(texts, padding=True, return_tensors="pt") diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 57c45df6ec..c7fb00e12d 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -39,6 +39,7 @@ from optimum.intel import ( OVConfig, + OVLatentConsistencyModelPipeline, OVModelForAudioClassification, OVModelForCausalLM, OVModelForFeatureExtraction, @@ -157,10 +158,10 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 44), ) - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 62, 365),) - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 385),) + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 62, 86),) + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 148),) SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = ( - (OVModelForCausalLM, 
"hf-internal-testing/tiny-random-OPTForCausalLM", 14, 136), + (OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 14, 50), ) SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ( (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 44), @@ -233,6 +234,12 @@ class OVWeightCompressionTest(unittest.TestCase): (OVStableDiffusionXLPipeline, "stable-diffusion-xl"), ) + SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = ( + (OVStableDiffusionPipeline, "stable-diffusion", 72, 195), + (OVStableDiffusionXLPipeline, "stable-diffusion-xl", 84, 331), + (OVLatentConsistencyModelPipeline, "latent-consistency", 50, 135), + ) + IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") DEFAULT_INT4_CONFIG = {"bits": 4, "sym": True, "group_size": 64, "all_layers": True} @@ -352,6 +359,38 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_ov_int8[i], num_int8) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION) + def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8): + model_id = MODEL_NAMES[model_type] + quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2) + with tempfile.TemporaryDirectory() as tmp_dir: + model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config) + + num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet) + self.assertEqual(expected_num_fake_quantize, num_fake_quantize) + self.assertEqual(expected_ov_int8, num_int8) + self.assertEqual(0, num_int4) + + model.save_pretrained(tmp_dir) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:]) + def test_ovmodel_hybrid_quantization_with_custom_dataset( + self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8 + ): + model_id = MODEL_NAMES[model_type] + dataset = [ + "dream rose covered with clean crystal, sharp edges, transparent, beautiful, highly detailed, high render" + ] + model = model_cls.from_pretrained( + model_id, + export=True, + quantization_config=OVWeightQuantizationConfig(bits=8, dataset=dataset, num_samples=3), + ) + num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet) + self.assertEqual(expected_num_fake_quantize, num_fake_quantize) + self.assertEqual(expected_ov_int8, num_int8) + self.assertEqual(0, num_int4) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS) @unittest.mock.patch.dict( "optimum.intel.openvino.configuration._DEFAULT_4BIT_CONFIGS", {"facebook/opt-125m": DEFAULT_INT4_CONFIG} diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 04049172d3..ad3cd03d3d 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -22,12 +22,14 @@ "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", "bert": "hf-internal-testing/tiny-random-bert", "bart": "hf-internal-testing/tiny-random-bart", + "baichuan2": "katuni4ka/tiny-random-baichuan2", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": 
"hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "chatglm": "katuni4ka/tiny-random-chatglm2", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", @@ -38,6 +40,7 @@ "convnext": "hf-internal-testing/tiny-random-convnext", "distilbert": "hf-internal-testing/tiny-random-distilbert", "electra": "hf-internal-testing/tiny-random-electra", + "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", @@ -55,7 +58,9 @@ "opt125m": "facebook/opt-125m", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", + "minicpm": "katuni4ka/tiny-random-minicpm", "mistral": "echarlaix/tiny-random-mistral", + "mixtral": "TitanML/tiny-mixtral", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet_v1": "google/mobilenet_v1_0.75_192", "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", @@ -66,6 +71,8 @@ "pegasus": "hf-internal-testing/tiny-random-pegasus", "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", + "qwen": "katuni4ka/tiny-random-qwen", + "qwen2": "Qwen/Qwen1.5-0.5B", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", @@ -76,6 +83,7 @@ "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner", + "stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "sew": "hf-internal-testing/tiny-random-SEWModel", "sew_d": "asapp/sew-d-tiny-100k-ft-ls100h", @@ -116,7 +124,7 @@ "stable-diffusion-xl-refiner": (366, 34, 42, 66), } -_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (62, 477)} +_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (62, 86)} def get_num_quantized_nodes(ov_model): @@ -127,8 +135,8 @@ def get_num_quantized_nodes(ov_model): if "FakeQuantize" in elem.name: num_fake_quantize += 1 for i in range(elem.get_output_size()): - if "8" in elem.get_output_element_type(i).get_type_name(): + if elem.get_output_element_type(i).get_type_name() in ["i8", "u8"]: num_int8 += 1 - if "4" in elem.get_output_element_type(i).get_type_name(): + if elem.get_output_element_type(i).get_type_name() in ["i4", "u4"]: num_int4 += 1 return num_fake_quantize, num_int8, num_int4