Skip to content

Commit

Permalink
Add section
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jun 24, 2024
1 parent a739930 commit 33dc386
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 48 deletions.
100 changes: 78 additions & 22 deletions docs/source/openvino/export.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,82 @@ To export your model to the [OpenVINO IR](https://docs.openvino.ai/2024/document
optimum-cli export openvino --model gpt2 ov_model/
```

The model argument can either be the model ID of a model hosted on the [Hub](https://huggingface.co/models) or a path to a model hosted locally.
The model argument can either be the model ID of a model hosted on the [Hub](https://huggingface.co/models) or a path to a model hosted locally. For local models, you need to specify the task for which the model should be loaded before export, among the list of the [supported tasks](https://huggingface.co/docs/optimum/main/en/exporters/task_manager).


Check out the help for more options:

```bash
optimum-cli export openvino --help
optimum-cli export openvino --model local_model_dir --task text-generation-with-past ov_model/
```

#### Task
The `-with-past` suffix enable the re-use of the pre-computed key/values hidden-states and is the recommended option, to export the model without, you will need to remove this suffix.

Specifying a --task should not be necessary in most cases when exporting from a model on the Hugging Face Hub.
| With K-V cache | Without K-V cache |
|------------------------------------------|--------------------------------------|
| `text-generation-with-past` | `text-generation` |
| `text2text-generation-with-past` | `text2text-generation` |
| `automatic-speech-recognition-with-past` | `automatic-speech-recognition` |

If the task argument is not provided, it will be automatically inferred.

For local models, you need to specify it among the list of the [supported tasks](https://huggingface.co/docs/optimum/main/en/exporters/task_manager):
Check out the help for more options:

```bash
optimum-cli export openvino --model local_model_dir --task text-generation-with-past ov_model/
```

#### Exporting a model using past keys/values in the decoder

When exporting a decoder model used for generation, it can be useful to encapsulate in the exported model the [reuse of past keys and values](https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/2). This allows to avoid recomputing the same intermediate activations during the generation.

This behavior corresponds to `--task text-geeneration-with-past`, `--task text2text-generation-with-past`, or `--task automatic-speech-recognition-with-past`. If for any purpose you would like to disable the export with past keys/values reuse, passing explicitly to `optimum-cli export openvino` the task `text2text-generation`, `text-generation` or `automatic-speech-recognition` is required.
optimum-cli export openvino --help

usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code] [--weight-format {fp32,fp16,int8,int4}]
[--library {transformers,diffusers,timm,sentence_transformers}] [--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym]
[--group-size GROUP_SIZE] [--dataset DATASET] [--all-layers] [--awq] [--scale-estimation] [--sensitivity-metric SENSITIVITY_METRIC] [--num-samples NUM_SAMPLES]
[--disable-stateful] [--disable-convert-tokenizer]
output

optional arguments:
-h, --help show this help message and exit

Required arguments:
--model MODEL Model ID on huggingface.co or path on disk to load model from.

output Path indicating the directory where to store the generated OV model.

Optional arguments:
--task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['image-segmentation',
'feature-extraction', 'mask-generation', 'audio-classification', 'conversational', 'stable-diffusion-xl', 'question-answering', 'sentence-similarity', 'text2text-generation',
'masked-im', 'automatic-speech-recognition', 'fill-mask', 'image-to-text', 'text-generation', 'zero-shot-object-detection', 'multiple-choice', 'object-detection', 'stable-
diffusion', 'audio-xvector', 'text-to-audio', 'zero-shot-image-classification', 'token-classification', 'image-classification', 'depth-estimation', 'image-to-image', 'audio-
frame-classification', 'semantic-segmentation', 'text-classification']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder.
--framework {pt,tf} The framework to use for the export. If not provided, will attempt to use the local checkpoint's original framework or what is available in the environment.
--trust-remote-code Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it
will execute on your local machine arbitrary code present in the model repository.
--weight-format {fp32,fp16,int8,int4}
The weight format of the exported model.
--library {transformers,diffusers,timm,sentence_transformers}
The library used to load the model before export. If not provided, will attempt to infer the local checkpoint's library.
--cache_dir CACHE_DIR
The path to a directory in which the downloaded model should be cached if the standard cache should not be used.
--pad-token-id PAD_TOKEN_ID
This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
--ratio RATIO A parameter used when applying 4-bit quantization to control the ratio between 4-bit and 8-bit quantization. If set to 0.8, 80% of the layers will be quantized to int4 while
20% will be quantized to int8. This helps to achieve better accuracy at the sacrifice of the model size and inference latency.
--sym Whether to apply symmetric quantization
--group-size GROUP_SIZE
The group size to use for int4 quantization. Recommended value is 128 and -1 will results in per-column quantization.
--dataset DATASET The dataset used for data-aware compression or quantization with NNCF. You can use the one from the list ['wikitext2','c4','c4-new'] for language models or
['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
--all-layers Whether embeddings and last MatMul layers should be compressed to INT4. If not provided an weight compression is applied, they are compressed to INT8.
--awq Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires additional time for tuning weights on a calibration dataset. To run AWQ,
please also provide a dataset argument. Note: it's possible that there will be no matching patterns in the model to apply AWQ, in such case it will be skipped.
--scale-estimation Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and compressed layers. Providing a dataset is required to run scale
estimation. Please note, that applying scale estimation takes additional memory and time.
--sensitivity-metric SENSITIVITY_METRIC
The sensitivity metric for assigning quantization precision to layers. Can be one of the following: ['weight_quantization_error', 'hessian_input_activation',
'mean_activation_variance', 'max_activation_variance', 'mean_activation_magnitude'].
--num-samples NUM_SAMPLES
The maximum number of samples to take from the dataset for quantization.
--disable-stateful Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. In stateful models all kv-cache
inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. If --disable-stateful option is used, it may result in sub-optimal inference
performance. Use it when you intentionally want to use a stateless model, for example, to be compatible with existing OpenVINO native inference code that expects kv-cache inputs
and outputs in the model.
--disable-convert-tokenizer
Do not add converted tokenizer and detokenizer OpenVINO models.
```

#### Quantization

Expand All @@ -62,18 +111,25 @@ Models larger than 1 billion parameters are exported to the OpenVINO format with

</Tip>

Once the model is exported, you can now [load your OpenVINO model](inference)

#### Custom export
Once the model is exported, you can now [load your OpenVINO model](inference) by replacing the `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.

<Tip>
#### When loading your model

You can also load your PyTorch checkpoint and convert it to the OpenVINO format on-the-fly, by setting `export=True` when loading your model.

```python
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)
model.save_pretrained("ov_model")

```

</Tip>
#### After loading your model

```python
from transfomers import AutoModelForCausalLM
from optimum.exporters.openvino import export_from_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
export_from_model(model, output="ov_model", task="text-generation-with-past")
```
49 changes: 25 additions & 24 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def parse_args_openvino(parser: "ArgumentParser"):
f" {str(TasksManager.get_all_tasks())}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
),
)
optional_group.add_argument(
"--cache_dir", type=str, default=HUGGINGFACE_HUB_CACHE, help="Path indicating where to store cache."
)
optional_group.add_argument(
"--framework",
type=str,
Expand All @@ -72,31 +69,40 @@ def parse_args_openvino(parser: "ArgumentParser"):
),
)
optional_group.add_argument(
"--pad-token-id",
type=int,
"--weight-format",
type=str,
choices=["fp32", "fp16", "int8", "int4", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
default=None,
help=(
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
help="he weight format of the exported model.",
)
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16")
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8")
optional_group.add_argument(
"--weight-format",
"--library",
type=str,
choices=["fp32", "fp16", "int8", "int4", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
default=None,
help="The library used to laod the model before export. If not provided, will attempt to infer the local checkpoint's library",
)
optional_group.add_argument(
"--cache_dir",
type=str,
default=HUGGINGFACE_HUB_CACHE,
help="The path to a directory in which the downloaded model should be cached if the standard cache should not be used.",
)
optional_group.add_argument(
"--pad-token-id",
type=int,
default=None,
help=(
"The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, int4_* - for INT4 compressed weights."
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument(
"--ratio",
type=float,
default=None,
help=(
"Compression ratio between primary and backup precision. In the case of INT4, NNCF evaluates layer sensitivity and keeps the most impactful layers in INT8"
"precision (by default 20%% in INT8). This helps to achieve better accuracy after weight compression."
"A parameter used when applying 4-bit quantization to control the ratio between 4-bit and 8-bit quantization. If set to 0.8, 80%% of the layers will be quantized to int4 "
"while 20%% will be quantized to int8. This helps to achieve better accuracy at the sacrifice of the model size and inference latency."
),
)
optional_group.add_argument(
Expand All @@ -117,7 +123,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
default=None,
help=(
"The dataset used for data-aware compression or quantization with NNCF. "
"You can use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLLMs "
"You can use the one from the list ['wikitext2','c4','c4-new'] for language models "
"or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models."
),
)
Expand Down Expand Up @@ -183,20 +189,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
action="store_true",
help="Do not add converted tokenizer and detokenizer OpenVINO models.",
)
#TODO : deprecated
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16")
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8")
optional_group.add_argument(
"--convert-tokenizer",
action="store_true",
help="[Deprecated] Add converted tokenizer and detokenizer with OpenVINO Tokenizers.",
)

optional_group.add_argument(
"--library",
type=str,
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
default=None,
help=("The library on the model. If not provided, will attempt to infer the local checkpoint's library"),
)


class OVExportCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(name="openvino", help="Export PyTorch models to OpenVINO IR.")
Expand Down
4 changes: 2 additions & 2 deletions optimum/intel/openvino/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
dataset (`str or List[str]`, *optional*):
The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
in a list of strings or just use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLLMs
in a list of strings or just use the one from the list ['wikitext2','c4','c4-new'] for language models
or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
Alternatively, you can provide data objects via `calibration_dataset` argument
of `OVQuantizer.quantize()` method.
Expand Down Expand Up @@ -230,7 +230,7 @@ def post_init(self):
f"If you wish to provide a custom dataset, please use the `OVQuantizer` instead."
)
if self.dataset is not None and isinstance(self.dataset, str):
llm_datasets = ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]
llm_datasets = ["wikitext2", "c4", "c4-new"]
stable_diffusion_datasets = [
"conceptual_captions",
"laion/220k-GPT4Vision-captions-from-LIVIS",
Expand Down

0 comments on commit 33dc386

Please sign in to comment.