From 1ef49b1987b5866b281aa5da138b45aabf039518 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Nov 2023 17:52:26 +0400 Subject: [PATCH 01/14] Added compression options to CLI. Revised load_in_8bit --- optimum/commands/export/openvino.py | 17 +++- optimum/exporters/openvino/__main__.py | 22 ++--- optimum/exporters/openvino/convert.py | 90 ++++++++++++++----- .../intel/openvino/modeling_base_seq2seq.py | 2 +- optimum/intel/openvino/modeling_decoder.py | 2 +- tests/openvino/test_exporters_cli.py | 39 ++++++-- tests/openvino/test_quantization.py | 10 +-- tests/openvino/utils_tests.py | 36 +++++--- 8 files changed, 155 insertions(+), 63 deletions(-) diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 75d8db8f00..e9caa7e4d8 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -68,8 +68,17 @@ def parse_args_openvino(parser: "ArgumentParser"): "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it." ), ) - optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"), - optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"), + optional_group.add_argument( + "-c", + "--compress-weights", + type=str, + choices=["f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"], + default=None, + help=( + "The weight compression option, e.g. f16 stands for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights." + ), + ) + optional_group.add_argument("--ratio", type=float, default=0.8, help="Compression ratio between primary and backup precision (only relevant to INT4).") class OVExportCommand(BaseOptimumCLICommand): @@ -104,7 +113,7 @@ def run(self): cache_dir=self.args.cache_dir, trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, - fp16=self.args.fp16, - int8=self.args.int8, + compression_option=self.args.compress_weights, + compression_ratio=self.args.ratio # **input_shapes, ) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 782aa0bc0d..647bac7993 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -43,7 +43,6 @@ def main_export( output: Union[str, Path], task: str = "auto", device: str = "cpu", - fp16: Optional[bool] = False, framework: Optional[str] = None, cache_dir: Optional[str] = None, trust_remote_code: bool = False, @@ -56,7 +55,8 @@ def main_export( model_kwargs: Optional[Dict[str, Any]] = None, custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, - int8: Optional[bool] = None, + compression_option: Optional[str] = None, + compression_ratio: Optional[float] = None, **kwargs_shapes, ): """ @@ -77,8 +77,6 @@ def main_export( use `xxx-with-past` to export the model using past key values in the decoder. device (`str`, defaults to `"cpu"`): The device to use to do the export. Defaults to "cpu". - fp16 (`Optional[bool]`, defaults to `"False"`): - Use half precision during the export. PyTorch-only, requires `device="cuda"`. framework (`Optional[str]`, defaults to `None`): The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect the framework for the checkpoint. 
@@ -113,6 +111,11 @@ def main_export( fn_get_submodels (`Optional[Callable]`, defaults to `None`): Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. + compression_option (`Optional[str]`, defaults to `None`): + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. + compression_ratio (`Optional[float]`, defaults to `None`): + Compression ratio between primary and backup precision (only relevant to INT4). **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -123,7 +126,7 @@ def main_export( >>> main_export("gpt2", output="gpt2_onnx/") ``` """ - if int8 and not is_nncf_available(): + if compression_option is not None and compression_option != "f16" and not is_nncf_available(): raise ImportError( "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" ) @@ -241,12 +244,11 @@ def main_export( onnx_config = onnx_config_constructor(model.config) models_and_onnx_configs = {"model": (model, onnx_config)} - if int8 is None: - int8 = False + if compression_option is None: num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters() if num_parameters >= _MAX_UNCOMPRESSED_SIZE: if is_nncf_available(): - int8 = True + compression_option = "i8" logger.info("The model weights will be quantized to int8.") else: logger.warning( @@ -320,7 +322,7 @@ def main_export( output_names=files_subpaths, input_shapes=input_shapes, device=device, - fp16=fp16, - int8=int8, + compression_option=compression_option, + compression_ratio=compression_ratio, model_kwargs=model_kwargs, ) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 14636f1f77..cc8147ae7e 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -53,16 +53,41 @@ from transformers.modeling_tf_utils import TFPreTrainedModel -def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False): - if load_in_8bit: +def _save_model(model, path: str, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None): + if compression_option is not None and compression_option != "f16": if not is_nncf_available(): raise ImportError( "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" ) import nncf - - model = nncf.compress_weights(model) + + COMPRESSION_OPTIONS = { + "i8": { "mode": nncf.CompressWeightsMode.INT8 }, + "i4_sym_g128": { + "mode": nncf.CompressWeightsMode.INT4_SYM, + "group_size": 128, + "ratio": compression_ratio, + }, + "i4_asym_g128": { + "mode": nncf.CompressWeightsMode.INT4_ASYM, + "group_size": 128, + "ratio": compression_ratio, + }, + "i4_sym_g64": { + "mode": nncf.CompressWeightsMode.INT4_SYM, + "group_size": 64, + "ratio": compression_ratio, + }, + "i4_asym_g64": { + "mode": nncf.CompressWeightsMode.INT4_ASYM, + "group_size": 64, + "ratio": compression_ratio, + }, + } + model = 
nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option]) + + compress_to_fp16 = compression_option == "f16" save_model(model, path, compress_to_fp16) @@ -74,8 +99,8 @@ def export( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, - fp16: bool = False, - int8: bool = False, + compression_option: Optional[str] = None, + compression_ratio: Optional[float] = None, ) -> Tuple[List[str], List[str]]: """ Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation. @@ -92,6 +117,11 @@ def export( device (`str`, *optional*, defaults to `cpu`): The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for export on CUDA devices. + compression_option (`Optional[str]`, defaults to `None`): + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. + compression_ratio (`Optional[float]`, defaults to `None`): + Compression ratio between primary and backup precision (only relevant to INT4). input_shapes (`Optional[Dict]`, defaults to `None`): If specified, allows to use specific shapes for the example input provided to the exporter. @@ -116,9 +146,9 @@ def export( output, device=device, input_shapes=input_shapes, + compression_option=compression_option, + compression_ratio=compression_ratio, model_kwargs=model_kwargs, - fp16=fp16, - int8=int8, ) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): @@ -142,6 +172,9 @@ def export_tensorflow( config: OnnxConfig, opset: int, output: Path, + compression_option: Optional[str] = None, + compression_ratio: Optional[float] = None, + ): """ Export the TensorFlow model to OpenVINO format. @@ -160,7 +193,7 @@ def export_tensorflow( onnx_path = Path(output).with_suffix(".onnx") input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path) ov_model = convert_model(str(onnx_path)) - _save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False) + _save_model(ov_model, output.parent / output, compression_option=compression_option, compression_ratio=compression_ratio) return input_names, output_names, True @@ -172,8 +205,8 @@ def export_pytorch_via_onnx( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, - fp16: bool = False, - int8: bool = False, + compression_option: Optional[str] = None, + compression_ratio: Optional[float] = None, ): """ Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export. @@ -193,7 +226,12 @@ def export_pytorch_via_onnx( input_shapes (`optional[Dict]`, defaults to `None`): If specified, allows to use specific shapes for the example input provided to the exporter. model_kwargs (optional[Dict[str, Any]], defaults to `None`): - Additional kwargs for model export + Additional kwargs for model export. + compression_option (`Optional[str]`, defaults to `None`): + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. 
+ compression_ratio (`Optional[float]`, defaults to `None`): + Compression ratio between primary and backup precision (only relevant to INT4). Returns: `Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from @@ -215,8 +253,8 @@ def export_pytorch_via_onnx( _save_model( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, - compress_to_fp16=fp16, - load_in_8bit=int8, + compression_option=compression_option, + compression_ratio=compression_ratio ) return input_names, output_names, True @@ -229,8 +267,8 @@ def export_pytorch( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, - fp16: bool = False, - int8: bool = False, + compression_option: Optional[str] = None, + compression_ratio: Optional[float] = None, ) -> Tuple[List[str], List[str]]: """ Exports a PyTorch model to an OpenVINO Intermediate Representation. @@ -326,7 +364,8 @@ def ts_patched_forward(*args, **kwargs): except Exception as ex: logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX") return export_pytorch_via_onnx( - model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8 + model, config, opset, output, device, input_shapes, model_kwargs, compression_option=compression_option, + compression_ratio=compression_ratio ) ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} ordered_input_names = list(inputs) @@ -348,7 +387,7 @@ def ts_patched_forward(*args, **kwargs): inp_tensor.get_node().set_partial_shape(static_shape) inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype)) ov_model.validate_nodes_and_infer_types() - _save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8) + _save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio) clear_class_registry() del model gc.collect() @@ -365,8 +404,8 @@ def export_models( device: str = "cpu", input_shapes: Optional[Dict] = None, model_kwargs: Optional[Dict[str, Any]] = None, - fp16: bool = False, - int8: bool = False, + compression_option: Optional[str] = None, + compression_ratio: Optional[int] = None, ) -> Tuple[List[List[str]], List[List[str]]]: """ Export the models to OpenVINO IR format @@ -381,8 +420,13 @@ def export_models( export on CUDA devices. input_shapes (Optional[Dict], optional, Defaults to None): If specified, allows to use specific shapes for the example input provided to the exporter. + compression_option (`Optional[str]`, defaults to `None`): + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. + compression_ratio (`Optional[int]`, defaults to `None`): + Compression ratio between primary and backup precision (only relevant to INT4). model_kwargs (Optional[Dict[str, Any]], optional): - Additional kwargs for model export + Additional kwargs for model export. 
Raises: ValueError: if custom names set not equal of number of models @@ -411,8 +455,8 @@ def export_models( device=device, input_shapes=input_shapes, model_kwargs=model_kwargs, - fp16=fp16, - int8=int8, + compression_option=compression_option, + compression_ratio=compression_ratio, ) ) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 527adc4347..74a3083452 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -264,7 +264,7 @@ def _from_transformers( local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, - int8=load_in_8bit, + compression_option="i8" if load_in_8bit else None, ) config.save_pretrained(save_dir_path) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 0e018f9f62..f8acd0b5b8 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -268,7 +268,7 @@ class StoreAttr(object): local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, - int8=load_in_8bit, + compression_option="i8" if load_in_8bit else None, ) # Unpatch modules after GPTQ export diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index d2b9960258..f01ccc91ce 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -16,7 +16,7 @@ from tempfile import TemporaryDirectory from parameterized import parameterized -from utils_tests import _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes +from utils_tests import _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT4_INT8 from optimum.exporters.openvino.__main__ import main_export from optimum.intel import ( # noqa @@ -56,10 +56,21 @@ class OVCLIExportTestCase(unittest.TestCase): ("stable-diffusion-xl", "stable-diffusion-xl"), ("stable-diffusion-xl", "stable-diffusion-xl-refiner"), ) + + SUPPORTED_4BIT_ARCHITECTURES = ( + ("text-generation-with-past", "opt125m"), + ) + + SUPPORTED_4BIT_OPTIONS = ["i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"] + + TEST_4BIT_CONFIGURATONS = [] + for arch in SUPPORTED_4BIT_ARCHITECTURES: + for option in SUPPORTED_4BIT_OPTIONS: + TEST_4BIT_CONFIGURATONS.append([arch[0], arch[1], option]) - def _openvino_export(self, model_name: str, task: str, fp16: bool = False, int8: bool = False): + def _openvino_export(self, model_name: str, task: str, compression_option: str = None, compression_ratio: float = None): with TemporaryDirectory() as tmpdir: - main_export(model_name_or_path=model_name, output=tmpdir, task=task, fp16=fp16, int8=int8) + main_export(model_name_or_path=model_name, output=tmpdir, task=task, compression_option=compression_option, compression_ratio=compression_ratio) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_export(self, task: str, model_type: str): @@ -80,7 +91,7 @@ def test_exporters_cli(self, task: str, model_type: str): def test_exporters_cli_fp16(self, task: str, model_type: str): with TemporaryDirectory() as tmpdir: subprocess.run( - f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --fp16 {tmpdir}", + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights f16 {tmpdir}", shell=True, check=True, ) @@ -91,7 +102,7 @@ def test_exporters_cli_fp16(self, task: str, 
model_type: str): def test_exporters_cli_int8(self, task: str, model_type: str): with TemporaryDirectory() as tmpdir: subprocess.run( - f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --int8 {tmpdir}", + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights i8 {tmpdir}", shell=True, check=True, ) @@ -110,5 +121,21 @@ def test_exporters_cli_int8(self, task: str, model_type: str): expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_int8[i], num_int8) + + @parameterized.expand(TEST_4BIT_CONFIGURATONS) + def test_exporters_cli_int4(self, task: str, model_type: str, option: str): + with TemporaryDirectory() as tmpdir: + subprocess.run( + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights {option} {tmpdir}", + shell=True, + check=True, + ) + model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} + model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) + + expected_int8, expected_int4 = _ARCHITECTURES_TO_EXPECTED_INT4_INT8[model_type] + _, num_int8, num_int4 = get_num_quantized_nodes(model) + self.assertEqual(expected_int8, num_int8) + self.assertEqual(expected_int4, num_int4) \ No newline at end of file diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index c1ec95ea9b..694f4373c7 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -132,7 +132,7 @@ def preprocess_function(examples, tokenizer): model = model_cls.from_pretrained(tmp_dir) - num_fake_quantize, num_int8 = get_num_quantized_nodes(model) + num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_fake_quantize, num_fake_quantize) self.assertEqual(expected_int8, num_int8) @@ -176,7 +176,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i quantizer.quantize(save_directory=tmp_dir, weights_only=True) model = model_cls.from_pretrained(tmp_dir) - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_pt_int8, num_int8) tokens = tokenizer("This is a sample input", return_tensors="pt") @@ -201,7 +201,7 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int quantizer.quantize(save_directory=tmp_dir, weights_only=True) model = model_cls.from_pretrained(tmp_dir) - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_ov_int8, num_int8) tokens = tokenizer("This is a sample input", return_tensors="pt") @@ -222,7 +222,7 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_ov_int8[i], num_int8) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) @@ -238,7 +238,7 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type): models = [model] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(0, num_int8) diff 
--git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 2fa77052eb..0c43d9b455 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -51,6 +51,7 @@ "llama": "fxmarty/tiny-llama-fast-tokenizer", "m2m_100": "hf-internal-testing/tiny-random-m2m_100", "opt": "hf-internal-testing/tiny-random-OPTModel", + "opt125m": "facebook/opt-125m", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "mistral": "echarlaix/tiny-random-mistral", @@ -97,28 +98,37 @@ _ARCHITECTURES_TO_EXPECTED_INT8 = { - "bert": (34,), - "roberta": (34,), - "albert": (42,), - "vit": (31,), - "blenderbot": (35,), - "gpt2": (22,), - "wav2vec2": (15,), - "distilbert": (33,), - "t5": (32, 52, 42), - "stable-diffusion": (74, 4, 4, 32), - "stable-diffusion-xl": (148, 4, 4, 33), - "stable-diffusion-xl-refiner": (148, 4, 4, 33), + "bert": (68,), + "roberta": (68,), + "albert": (84,), + "vit": (62,), + "blenderbot": (70,), + "gpt2": (44,), + "wav2vec2": (30,), + "distilbert": (66,), + "t5": (64, 104, 84), + "stable-diffusion": (148, 8, 8, 64), + "stable-diffusion-xl": (296, 8, 8, 66), + "stable-diffusion-xl-refiner": (296, 4, 8, 66), +} + + +_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = { + "opt125m": (128, 64) } def get_num_quantized_nodes(ov_model): num_fake_quantize = 0 num_int8 = 0 + num_int4 = 0 for elem in ov_model.model.get_ops(): if "FakeQuantize" in elem.name: num_fake_quantize += 1 for i in range(elem.get_output_size()): if "8" in elem.get_output_element_type(i).get_type_name(): num_int8 += 1 - return num_fake_quantize, num_int8 + if "4" in elem.get_output_element_type(i).get_type_name(): + num_int4 += 1 + return num_fake_quantize, num_int8, num_int4 + From ceb73e4ebc18254dd29432454071b7a9a76760a5 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Nov 2023 19:32:13 +0400 Subject: [PATCH 02/14] Added 4 bit compression into quantizer --- optimum/intel/openvino/configuration.py | 2 +- optimum/intel/openvino/quantization.py | 32 +++++++++++++++++- tests/openvino/test_quantization.py | 43 ++++++++++++++++++++----- tests/openvino/utils_tests.py | 4 +-- 4 files changed, 69 insertions(+), 12 deletions(-) diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index a45ee281f6..57ed772184 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -107,7 +107,7 @@ def _enable_standard_onnx_export_option(self): # save_onnx_model is defaulted to false so that the final model output is # in OpenVINO IR to realize performance benefit in OpenVINO runtime. # True value of save_onnx_model will save a model in onnx format. 
- if isinstance(self.compression, dict) and self.compression["algorithm"] == "quantization": + if isinstance(self.compression, dict) and "algorithm" in self.compression and self.compression["algorithm"] == "quantization": self.compression["export_to_onnx_standard_ops"] = self.save_onnx_model elif isinstance(self.compression, list): for i, algo_config in enumerate(self.compression): diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index bcc7c2908b..ca6df15664 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -50,6 +50,25 @@ OV_XML_FILE_NAME, ) +COMPRESSION_OPTIONS = { + "i8": { "mode": nncf.CompressWeightsMode.INT8 }, + "i4_sym_g128": { + "mode": nncf.CompressWeightsMode.INT4_SYM, + "group_size": 128 + }, + "i4_asym_g128": { + "mode": nncf.CompressWeightsMode.INT4_ASYM, + "group_size": 128 + }, + "i4_sym_g64": { + "mode": nncf.CompressWeightsMode.INT4_SYM, + "group_size": 64 + }, + "i4_asym_g64": { + "mode": nncf.CompressWeightsMode.INT4_ASYM, + "group_size": 64 + }, +} register_module(ignored_algorithms=[])(Conv1D) @@ -186,6 +205,7 @@ def quantize( data_collator, remove_unused_columns, weights_only, + quantization_config, **kwargs, ) elif isinstance(self.model, OVBaseModel): @@ -212,6 +232,14 @@ def quantize( else: raise TypeError(f"Unsupported model type: {type(self.model)}") + def _get_compression_options(self, config: OVConfig): + options = {} + if config is not None and "type" in config.compression: + options = COMPRESSION_OPTIONS[config.compression["type"]] + if "ratio" in config.compression: + options["ratio"] = config.compression["ratio"] + return options + def _quantize_ovbasemodel( self, calibration_dataset: Dataset, @@ -256,13 +284,15 @@ def _quantize_ovcausallm( data_collator: Optional[DataCollator] = None, remove_unused_columns: bool = True, weights_only: bool = False, + quantization_config: OVConfig = None, **kwargs, ): save_directory = Path(save_directory) save_directory.mkdir(parents=True, exist_ok=True) if weights_only: - self.model.model = nncf.compress_weights(self.model.model) + options = self._get_compression_options(quantization_config) + self.model.model = nncf.compress_weights(self.model.model, **options) self.model.save_pretrained(save_directory) return diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 694f4373c7..d9302591b3 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -93,7 +93,7 @@ def preprocess_function(examples, tokenizer): model = model_cls.from_pretrained(tmp_dir, file_name=file_name) # TODO: uncomment once move to a newer version of NNCF which has some fixes (addmm, baddmm) - # num_fake_quantize, num_int8 = get_num_quantized_nodes(model) + # num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model) # self.assertEqual(expected_fake_quantize, num_fake_quantize) # self.assertEqual(expected_int8, num_int8) @@ -143,9 +143,13 @@ def preprocess_function(examples, tokenizer): class OVWeightCompressionTest(unittest.TestCase): # TODO : add models - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = ( - (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 35), - (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 22), + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ( + (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 70), + (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 
44), + ) + + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ( + (OVModelForCausalLM, "opt125m", 82, 323), ) SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = ( @@ -162,7 +166,7 @@ class OVWeightCompressionTest(unittest.TestCase): (OVStableDiffusionXLPipeline, "stable-diffusion-xl"), ) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS) def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): task = model_cls.export_feature @@ -187,8 +191,8 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i loaded_config = OVConfig.from_pretrained(tmp_dir) self.assertIsNotNone(loaded_config) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS) - def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS) + def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8): task = model_cls.export_feature with tempfile.TemporaryDirectory() as tmp_dir: @@ -207,6 +211,29 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int tokens = tokenizer("This is a sample input", return_tensors="pt") outputs = model(**tokens) self.assertTrue("logits" in outputs) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS) + def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4): + task = model_cls.export_feature + + with tempfile.TemporaryDirectory() as tmp_dir: + model_id = MODEL_NAMES[model_name] + transformers_model = model_cls.from_pretrained(model_id, export=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + quantizer.quantize(save_directory=tmp_dir, weights_only=True, quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8})) + model = model_cls.from_pretrained(tmp_dir) + + _, num_int8, num_int4 = get_num_quantized_nodes(model) + self.assertEqual(expected_int8, num_int8) + self.assertEqual(expected_int4, num_int4) + + tokens = tokenizer("This is a sample input", return_tensors="pt") + outputs = model(**tokens) + self.assertTrue("logits" in outputs) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): @@ -349,7 +376,7 @@ def compute_metrics(p): trainer.save_model() model = OVModelForSequenceClassification.from_pretrained(tmp_dir) - num_fake_quantize, num_int8 = get_num_quantized_nodes(model) + num_fake_quantize, num_int8, _ = get_num_quantized_nodes(model) self.assertEqual(expected_fake_quantize, num_fake_quantize) self.assertEqual(expected_int8, num_int8) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 0c43d9b455..cab10dd92e 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -109,12 +109,12 @@ "t5": (64, 104, 84), "stable-diffusion": (148, 8, 8, 64), "stable-diffusion-xl": (296, 8, 8, 66), - "stable-diffusion-xl-refiner": (296, 4, 8, 66), + "stable-diffusion-xl-refiner": (296, 8, 8, 66), } _ARCHITECTURES_TO_EXPECTED_INT4_INT8 = { - "opt125m": (128, 64) + "opt125m": (82, 323) } 
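For reference, a minimal usage sketch of the 4-bit weight-compression path introduced in this patch, mirroring the new test_ovmodel_4bit_weight_compression test; the checkpoint id, task and output directory below are illustrative, and the "type"/"ratio" keys map onto the COMPRESSION_OPTIONS table added to quantization.py.

# Hedged usage sketch (not part of the patch): 4-bit weight compression via OVQuantizer.
from optimum.intel import OVConfig, OVModelForCausalLM, OVQuantizer

model_id = "facebook/opt-125m"  # illustrative checkpoint; any causal LM exported to OpenVINO works the same way
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
quantizer = OVQuantizer.from_pretrained(model, task="text-generation")

# "type" selects one of the COMPRESSION_OPTIONS keys; "ratio" keeps a share of layers
# in INT8 backup precision (0.8 leaves roughly 20% of the weights in INT8).
quantizer.quantize(
    save_directory="opt125m-int4",  # illustrative output directory
    weights_only=True,
    quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8}),
)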
From 35cef0eaa3c07750e0512f3db96a96f92aa23246 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 2 Nov 2023 19:36:21 +0400 Subject: [PATCH 03/14] Temporary switched to NNCF develop and openvino-nightly --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6d81b98b2a..cf430e7845 100644 --- a/setup.py +++ b/setup.py @@ -42,8 +42,8 @@ "onnx", "onnxruntime<1.15.0", ], - "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"], - "nncf": ["nncf>=2.6.0"], + "openvino": ["openvino-nightly", "onnx", "onnxruntime"], + "nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"], "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, From b0831505c8327a2082b665c4f6b011c8d8c532b7 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 3 Nov 2023 10:18:55 +0400 Subject: [PATCH 04/14] Fixed tests --- tests/openvino/test_training.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index 91defbefbb..006a99f2f9 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -318,7 +318,7 @@ def tearDown(self): "default_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=DEFAULT_QUANTIZATION_CONFIG, - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, compression_metrics=["compression_loss"], ), @@ -326,14 +326,14 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=DEFAULT_QUANTIZATION_CONFIG, - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), "customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, compression_metrics=["compression_loss"], ), @@ -341,7 +341,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), @@ -361,7 +361,7 @@ def tearDown(self): "default_quantization,structured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -369,7 +369,7 @@ def tearDown(self): "customized_quantization,structured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -378,7 +378,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", 
nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -387,7 +387,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -408,7 +408,7 @@ def tearDown(self): "default_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -416,7 +416,7 @@ def tearDown(self): "customized_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -425,7 +425,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=43, + expected_fake_quantize=42, expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -434,7 +434,7 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=70, + expected_fake_quantize=69, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], From 320e94eb976fd9b9e7406e7d7e70c414e0c822ca Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 3 Nov 2023 10:19:21 +0400 Subject: [PATCH 05/14] Style --- optimum/commands/export/openvino.py | 7 +++++- optimum/exporters/openvino/__main__.py | 2 +- optimum/exporters/openvino/convert.py | 30 +++++++++++++++--------- optimum/intel/openvino/configuration.py | 6 ++++- optimum/intel/openvino/quantization.py | 23 +++++------------- tests/openvino/test_exporters_cli.py | 31 +++++++++++++++++-------- tests/openvino/test_quantization.py | 14 ++++++----- tests/openvino/utils_tests.py | 5 +--- 8 files changed, 67 insertions(+), 51 deletions(-) diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index e9caa7e4d8..c63ba7887a 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -78,7 +78,12 @@ def parse_args_openvino(parser: "ArgumentParser"): "The weight compression option, e.g. f16 stands for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights." 
), ) - optional_group.add_argument("--ratio", type=float, default=0.8, help="Compression ratio between primary and backup precision (only relevant to INT4).") + optional_group.add_argument( + "--ratio", + type=float, + default=0.8, + help="Compression ratio between primary and backup precision (only relevant to INT4).", + ) class OVExportCommand(BaseOptimumCLICommand): diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 647bac7993..076f0e3896 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -112,7 +112,7 @@ def main_export( Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. compression_option (`Optional[str]`, defaults to `None`): - The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. compression_ratio (`Optional[float]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index cc8147ae7e..5629861770 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -61,9 +61,9 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp ) import nncf - + COMPRESSION_OPTIONS = { - "i8": { "mode": nncf.CompressWeightsMode.INT8 }, + "i8": {"mode": nncf.CompressWeightsMode.INT8}, "i4_sym_g128": { "mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128, @@ -86,7 +86,7 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp }, } model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option]) - + compress_to_fp16 = compression_option == "f16" save_model(model, path, compress_to_fp16) @@ -118,7 +118,7 @@ def export( The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for export on CUDA devices. compression_option (`Optional[str]`, defaults to `None`): - The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. compression_ratio (`Optional[float]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). 
@@ -174,7 +174,6 @@ def export_tensorflow( output: Path, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, - ): """ Export the TensorFlow model to OpenVINO format. @@ -193,7 +192,9 @@ def export_tensorflow( onnx_path = Path(output).with_suffix(".onnx") input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path) ov_model = convert_model(str(onnx_path)) - _save_model(ov_model, output.parent / output, compression_option=compression_option, compression_ratio=compression_ratio) + _save_model( + ov_model, output.parent / output, compression_option=compression_option, compression_ratio=compression_ratio + ) return input_names, output_names, True @@ -228,7 +229,7 @@ def export_pytorch_via_onnx( model_kwargs (optional[Dict[str, Any]], defaults to `None`): Additional kwargs for model export. compression_option (`Optional[str]`, defaults to `None`): - The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. compression_ratio (`Optional[float]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). @@ -254,7 +255,7 @@ def export_pytorch_via_onnx( ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, compression_option=compression_option, - compression_ratio=compression_ratio + compression_ratio=compression_ratio, ) return input_names, output_names, True @@ -364,8 +365,15 @@ def ts_patched_forward(*args, **kwargs): except Exception as ex: logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX") return export_pytorch_via_onnx( - model, config, opset, output, device, input_shapes, model_kwargs, compression_option=compression_option, - compression_ratio=compression_ratio + model, + config, + opset, + output, + device, + input_shapes, + model_kwargs, + compression_option=compression_option, + compression_ratio=compression_ratio, ) ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs} ordered_input_names = list(inputs) @@ -421,7 +429,7 @@ def export_models( input_shapes (Optional[Dict], optional, Defaults to None): If specified, allows to use specific shapes for the example input provided to the exporter. compression_option (`Optional[str]`, defaults to `None`): - The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, + The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. compression_ratio (`Optional[int]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). 
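For reference, a hedged sketch of driving the same weight-compression options programmatically through main_export, equivalent to the `optimum-cli export openvino ... --compress-weights i8` invocation exercised in the CLI tests; the model id and output directory are illustrative.

# Hedged usage sketch (not part of the patch): weight compression through main_export.
from optimum.exporters.openvino.__main__ import main_export

main_export(
    model_name_or_path="gpt2",         # illustrative checkpoint
    output="gpt2_ov/",                 # illustrative output directory
    task="text-generation-with-past",
    compression_option="i8",           # or "f16", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"
    compression_ratio=0.8,             # only consulted for the i4_* options
)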
diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index 57ed772184..37928289e4 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -107,7 +107,11 @@ def _enable_standard_onnx_export_option(self): # save_onnx_model is defaulted to false so that the final model output is # in OpenVINO IR to realize performance benefit in OpenVINO runtime. # True value of save_onnx_model will save a model in onnx format. - if isinstance(self.compression, dict) and "algorithm" in self.compression and self.compression["algorithm"] == "quantization": + if ( + isinstance(self.compression, dict) + and "algorithm" in self.compression + and self.compression["algorithm"] == "quantization" + ): self.compression["export_to_onnx_standard_ops"] = self.save_onnx_model elif isinstance(self.compression, list): for i, algo_config in enumerate(self.compression): diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index ca6df15664..c05a98bb3e 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -50,24 +50,13 @@ OV_XML_FILE_NAME, ) + COMPRESSION_OPTIONS = { - "i8": { "mode": nncf.CompressWeightsMode.INT8 }, - "i4_sym_g128": { - "mode": nncf.CompressWeightsMode.INT4_SYM, - "group_size": 128 - }, - "i4_asym_g128": { - "mode": nncf.CompressWeightsMode.INT4_ASYM, - "group_size": 128 - }, - "i4_sym_g64": { - "mode": nncf.CompressWeightsMode.INT4_SYM, - "group_size": 64 - }, - "i4_asym_g64": { - "mode": nncf.CompressWeightsMode.INT4_ASYM, - "group_size": 64 - }, + "i8": {"mode": nncf.CompressWeightsMode.INT8}, + "i4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128}, + "i4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128}, + "i4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64}, + "i4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64}, } register_module(ignored_algorithms=[])(Conv1D) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index f01ccc91ce..917a1ccdb6 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -16,7 +16,12 @@ from tempfile import TemporaryDirectory from parameterized import parameterized -from utils_tests import _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT4_INT8 +from utils_tests import ( + _ARCHITECTURES_TO_EXPECTED_INT4_INT8, + _ARCHITECTURES_TO_EXPECTED_INT8, + MODEL_NAMES, + get_num_quantized_nodes, +) from optimum.exporters.openvino.__main__ import main_export from optimum.intel import ( # noqa @@ -56,21 +61,27 @@ class OVCLIExportTestCase(unittest.TestCase): ("stable-diffusion-xl", "stable-diffusion-xl"), ("stable-diffusion-xl", "stable-diffusion-xl-refiner"), ) - - SUPPORTED_4BIT_ARCHITECTURES = ( - ("text-generation-with-past", "opt125m"), - ) - + + SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),) + SUPPORTED_4BIT_OPTIONS = ["i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"] - + TEST_4BIT_CONFIGURATONS = [] for arch in SUPPORTED_4BIT_ARCHITECTURES: for option in SUPPORTED_4BIT_OPTIONS: TEST_4BIT_CONFIGURATONS.append([arch[0], arch[1], option]) - def _openvino_export(self, model_name: str, task: str, compression_option: str = None, compression_ratio: float = None): + def _openvino_export( + self, model_name: str, task: str, compression_option: str = None, 
compression_ratio: float = None + ): with TemporaryDirectory() as tmpdir: - main_export(model_name_or_path=model_name, output=tmpdir, task=task, compression_option=compression_option, compression_ratio=compression_ratio) + main_export( + model_name_or_path=model_name, + output=tmpdir, + task=task, + compression_option=compression_option, + compression_ratio=compression_ratio, + ) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_export(self, task: str, model_type: str): @@ -138,4 +149,4 @@ def test_exporters_cli_int4(self, task: str, model_type: str, option: str): expected_int8, expected_int4 = _ARCHITECTURES_TO_EXPECTED_INT4_INT8[model_type] _, num_int8, num_int4 = get_num_quantized_nodes(model) self.assertEqual(expected_int8, num_int8) - self.assertEqual(expected_int4, num_int4) \ No newline at end of file + self.assertEqual(expected_int4, num_int4) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index d9302591b3..82535609e7 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -147,10 +147,8 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 70), (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 44), ) - - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ( - (OVModelForCausalLM, "opt125m", 82, 323), - ) + + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 82, 323),) SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = ( (OVModelForCausalLM, "gpt2"), @@ -211,7 +209,7 @@ def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_p tokens = tokenizer("This is a sample input", return_tensors="pt") outputs = model(**tokens) self.assertTrue("logits" in outputs) - + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS) def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4): task = model_cls.export_feature @@ -224,7 +222,11 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i tokenizer.pad_token = tokenizer.eos_token quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) - quantizer.quantize(save_directory=tmp_dir, weights_only=True, quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8})) + quantizer.quantize( + save_directory=tmp_dir, + weights_only=True, + quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8}), + ) model = model_cls.from_pretrained(tmp_dir) _, num_int8, num_int4 = get_num_quantized_nodes(model) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index cab10dd92e..2e0362f2bf 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -113,9 +113,7 @@ } -_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = { - "opt125m": (82, 323) -} +_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (82, 323)} def get_num_quantized_nodes(ov_model): @@ -131,4 +129,3 @@ def get_num_quantized_nodes(ov_model): if "4" in elem.get_output_element_type(i).get_type_name(): num_int4 += 1 return num_fake_quantize, num_int8, num_int4 - From e32328054a746c5748fd9e5a69df66416a10caa9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlov Date: Thu, 9 Nov 2023 16:35:02 +0400 Subject: [PATCH 06/14] Update optimum/exporters/openvino/__main__.py Co-authored-by: Nico Galoppo --- optimum/exporters/openvino/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 56e8f050be..ce64f27559 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -128,7 +128,7 @@ def main_export( """ if compression_option is not None and compression_option != "f16" and not is_nncf_available(): raise ImportError( - "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" + f"Compression of the weights to {compression_option} requires nncf, please install it with `pip install nncf`" ) model_kwargs = model_kwargs or {} From d878453215c0eff3c4bf83cbd473833937ceec70 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 Dec 2023 13:16:33 +0400 Subject: [PATCH 07/14] Added FP32 option for weights data type --- optimum/commands/export/openvino.py | 11 +++++------ optimum/exporters/openvino/__main__.py | 4 ++-- optimum/exporters/openvino/convert.py | 2 +- setup.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index c63ba7887a..25626e20aa 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -69,20 +69,19 @@ def parse_args_openvino(parser: "ArgumentParser"): ), ) optional_group.add_argument( - "-c", - "--compress-weights", + "--weight-format", type=str, - choices=["f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"], + choices=["f32", "f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"], default=None, help=( - "The weight compression option, e.g. f16 stands for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights." + "The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights." ), ) optional_group.add_argument( "--ratio", type=float, default=0.8, - help="Compression ratio between primary and backup precision (only relevant to INT4).", + help="Compression ratio between primary and backup precision (only applicable to INT4 type).", ) @@ -118,7 +117,7 @@ def run(self): cache_dir=self.args.cache_dir, trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, - compression_option=self.args.compress_weights, + compression_option=self.args.weights_format, compression_ratio=self.args.ratio # **input_shapes, ) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 1ab2ed4b00..30f0c17038 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -121,7 +121,7 @@ def main_export( especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. compression_option (`Optional[str]`, defaults to `None`): The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point, - `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point. + `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression. 
compression_ratio (`Optional[float]`, defaults to `None`): Compression ratio between primary and backup precision (only relevant to INT4). **kwargs_shapes (`Dict`): @@ -134,7 +134,7 @@ def main_export( >>> main_export("gpt2", output="gpt2_onnx/") ``` """ - if compression_option is not None and compression_option != "f16" and not is_nncf_available(): + if compression_option is not None and compression_option != "f16" and compression_option != "f32" and not is_nncf_available(): raise ImportError( f"Compression of the weights to {compression_option} requires nncf, please install it with `pip install nncf`" ) diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 18885bfa22..449e1264b8 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -55,7 +55,7 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None): - if compression_option is not None and compression_option != "f16": + if compression_option is not None and compression_option != "f16" and compression_option != "f32": if not is_nncf_available(): raise ImportError( "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`" diff --git a/setup.py b/setup.py index 14668bb655..2863a620b5 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ "onnxruntime<1.15.0", "transformers>=4.33.0", ], - "openvino": ["openvino-nightly", "onnx", "onnxruntime", "transformers>=4.33.0"], + "openvino": ["openvino", "onnx", "onnxruntime", "transformers>=4.33.0"], "nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"], "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], "diffusers": ["diffusers"], From 7f3b7cf6d52f14e191979cb3d96b55a616afae59 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 Dec 2023 13:40:19 +0400 Subject: [PATCH 08/14] Style --- optimum/exporters/openvino/__main__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 30f0c17038..542199fa4f 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -134,7 +134,12 @@ def main_export( >>> main_export("gpt2", output="gpt2_onnx/") ``` """ - if compression_option is not None and compression_option != "f16" and compression_option != "f32" and not is_nncf_available(): + if ( + compression_option is not None + and compression_option != "f16" + and compression_option != "f32" + and not is_nncf_available() + ): raise ImportError( f"Compression of the weights to {compression_option} requires nncf, please install it with `pip install nncf`" ) From 4c87f032b0f4db612be9aa6ab599e783cb70aa70 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 Dec 2023 14:28:10 +0400 Subject: [PATCH 09/14] Fixed issue --- optimum/commands/export/openvino.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 25626e20aa..088194740c 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -117,7 +117,7 @@ def run(self): cache_dir=self.args.cache_dir, trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, - compression_option=self.args.weights_format, + compression_option=self.args.weight_format, compression_ratio=self.args.ratio # **input_shapes, ) From effb7440272ecaecd14e3bc3a537452e6e025cb3 Mon Sep 17 
From: Alexander
Date: Fri, 15 Dec 2023 14:59:39 +0400
Subject: [PATCH 10/14] Fixed setup.py

---
 setup.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 2863a620b5..7bc51f9c72 100644
--- a/setup.py
+++ b/setup.py
@@ -43,8 +43,8 @@
         "onnxruntime<1.15.0",
         "transformers>=4.33.0",
     ],
-    "openvino": ["openvino", "onnx", "onnxruntime", "transformers>=4.33.0"],
-    "nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"],
+    "openvino": ["openvino>=2023.2", "onnx", "onnxruntime", "transformers>=4.33.0"],
+    "nncf": ["nncf>=2.7.0"],
     "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,

From 96f0af545222a8c5d309b4ccd1394bc0cc7f79fa Mon Sep 17 00:00:00 2001
From: Alexander
Date: Fri, 15 Dec 2023 16:42:57 +0400
Subject: [PATCH 11/14] Applied some comments

---
 optimum/commands/export/openvino.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 088194740c..55aacd3679 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Defines the command line for the export with OpenVINO."""
+import logging
 import sys
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional
@@ -21,6 +22,9 @@
 from ..base import BaseOptimumCLICommand, CommandInfo
+logger = logging.getLogger(__name__)
+
+
 if TYPE_CHECKING:
     from argparse import ArgumentParser, Namespace, _SubParsersAction
@@ -68,6 +72,8 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
         ),
     )
+    optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16")
+    optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8")
     optional_group.add_argument(
         "--weight-format",
         type=str,
@@ -81,7 +87,10 @@ def parse_args_openvino(parser: "ArgumentParser"):
         "--ratio",
         type=float,
         default=0.8,
-        help="Compression ratio between primary and backup precision (only applicable to INT4 type).",
+        help=(
+            "Compression ratio between primary and backup precision. In the case of INT4, NNCF evaluates layer sensitivity and keeps the most impactful layers in INT8 "
+            "precision (by default 20%% in INT8). This helps to achieve better accuracy after weight quantization."
+        ),
     )
@@ -108,6 +117,17 @@ def parse_args(parser: "ArgumentParser"):
     def run(self):
         from ...exporters.openvino.__main__ import main_export
+        if self.args.fp16:
+            logger.warning(
+                "`--fp16` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
+            )
+            self.args.weight_format = "f16"
+        if self.args.int8:
+            logger.warning(
+                "`--int8` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
+            )
+            self.args.weight_format = "i8"
+
         # TODO : add input shapes
         main_export(
             model_name_or_path=self.args.model,
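For context on what the `--ratio` flag above ultimately controls: for the 4-bit weight formats it is forwarded to `nncf.compress_weights` as the `ratio` argument (see the COMPRESSION_OPTIONS mappings elsewhere in this series). A minimal sketch, not part of the patch, assuming `ov_model` is an already converted OpenVINO model:

    import nncf

    compressed_model = nncf.compress_weights(
        ov_model,                                # assumption: an openvino.runtime.Model converted beforehand
        mode=nncf.CompressWeightsMode.INT4_SYM,  # the *_sym/*_asym part of the weight format selects SYM or ASYM
        group_size=128,                          # the g128/g64 suffix selects the group size
        ratio=0.8,                               # ~80% of weights go to INT4, the most sensitive ~20% stay INT8
    )

Lowering the ratio trades footprint for accuracy, which is why the CLI exposes it alongside the INT4 formats rather than hard-coding it.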
From c55610b22790e9c50f94e6e1e4024177316cb709 Mon Sep 17 00:00:00 2001
From: Alexander
Date: Mon, 18 Dec 2023 11:39:46 +0400
Subject: [PATCH 12/14] Fixed names of precisions

---
 optimum/commands/export/openvino.py          |  8 +++---
 optimum/exporters/openvino/__main__.py       | 10 +++----
 optimum/exporters/openvino/convert.py        | 26 +++++++++----------
 .../intel/openvino/modeling_base_seq2seq.py  |  2 +-
 optimum/intel/openvino/modeling_decoder.py   |  2 +-
 optimum/intel/openvino/quantization.py       | 10 +++----
 tests/openvino/test_exporters_cli.py         |  8 +++---
 7 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 55aacd3679..95ecea1213 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -77,10 +77,10 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--weight-format",
         type=str,
-        choices=["f32", "f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"],
+        choices=["fp32", "fp16", "int8", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
         default=None,
         help=(
-            "The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights."
+            "The weight format of the exported model, e.g. fp32 stands for float32 weights, fp16 - for float16 weights, int8 - INT8 weights, int4_* - for INT4 compressed weights."
         ),
     )
     optional_group.add_argument(
@@ -121,12 +121,12 @@ def run(self):
             logger.warning(
                 "`--fp16` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
             )
-            self.args.weight_format = "f16"
+            self.args.weight_format = "fp16"
         if self.args.int8:
             logger.warning(
                 "`--int8` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
             )
-            self.args.weight_format = "i8"
+            self.args.weight_format = "int8"

         # TODO : add input shapes
         main_export(
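With the user-facing names settled as fp32/fp16/int8/int4_*, a typical invocation matches what the CLI tests further down exercise. A hypothetical example, driven from Python the same way the tests do (the model id and output directory are placeholders):

    import subprocess

    subprocess.run(
        "optimum-cli export openvino --model facebook/opt-125m --task text-generation-with-past "
        "--weight-format int4_sym_g128 --ratio 0.8 opt125m_int4_ov",
        shell=True,
        check=True,
    )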
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 542199fa4f..54fe1193e5 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -120,8 +120,8 @@ def main_export(
            Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
+            The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point, `fp32` - means no compression.
        compression_ratio (`Optional[float]`, defaults to `None`):
            Compression ratio between primary and backup precision (only relevant to INT4).
        **kwargs_shapes (`Dict`):
@@ -136,8 +136,8 @@ def main_export(
     """
     if (
         compression_option is not None
-        and compression_option != "f16"
-        and compression_option != "f32"
+        and compression_option != "fp16"
+        and compression_option != "fp32"
         and not is_nncf_available()
     ):
         raise ImportError(
@@ -297,7 +297,7 @@ class StoreAttr(object):
         num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
         if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
             if is_nncf_available():
-                compression_option = "i8"
+                compression_option = "int8"
                 logger.info("The model weights will be quantized to int8.")
             else:
                 logger.warning(
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 449e1264b8..56c5a10e5d 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -55,7 +55,7 @@
 def _save_model(model, path: str, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None):
-    if compression_option is not None and compression_option != "f16" and compression_option != "f32":
+    if compression_option is not None and compression_option != "fp16" and compression_option != "fp32":
         if not is_nncf_available():
             raise ImportError(
                 "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
@@ -64,23 +64,23 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp
         import nncf

         COMPRESSION_OPTIONS = {
-            "i8": {"mode": nncf.CompressWeightsMode.INT8},
-            "i4_sym_g128": {
+            "int8": {"mode": nncf.CompressWeightsMode.INT8},
+            "int4_sym_g128": {
                 "mode": nncf.CompressWeightsMode.INT4_SYM,
                 "group_size": 128,
                 "ratio": compression_ratio,
             },
-            "i4_asym_g128": {
+            "int4_asym_g128": {
                 "mode": nncf.CompressWeightsMode.INT4_ASYM,
                 "group_size": 128,
                 "ratio": compression_ratio,
             },
-            "i4_sym_g64": {
+            "int4_sym_g64": {
                 "mode": nncf.CompressWeightsMode.INT4_SYM,
                 "group_size": 64,
                 "ratio": compression_ratio,
             },
-            "i4_asym_g64": {
+            "int4_asym_g64": {
                 "mode": nncf.CompressWeightsMode.INT4_ASYM,
                 "group_size": 64,
                 "ratio": compression_ratio,
@@ -88,7 +88,7 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp
             },
         }
         model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option])

-    compress_to_fp16 = compression_option == "f16"
+    compress_to_fp16 = compression_option == "fp16"
     save_model(model, path, compress_to_fp16)
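Note that the `fp16` and `fp32` options never reach NNCF; they only decide whether weights are converted to half precision when the IR is serialized. Roughly equivalent, with `ov_model` and the output path as placeholders:

    import openvino

    openvino.save_model(ov_model, "model/openvino_model.xml", compress_to_fp16=True)   # fp16
    openvino.save_model(ov_model, "model/openvino_model.xml", compress_to_fp16=False)  # fp32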
@@ -119,8 +119,8 @@ def export(
            The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.
        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+            The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
        compression_ratio (`Optional[float]`, defaults to `None`):
            Compression ratio between primary and backup precision (only relevant to INT4).
        input_shapes (`Optional[Dict]`, defaults to `None`):
@@ -230,8 +230,8 @@ def export_pytorch_via_onnx(
        model_kwargs (optional[Dict[str, Any]], defaults to `None`):
            Additional kwargs for model export.
        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+            The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
        compression_ratio (`Optional[float]`, defaults to `None`):
            Compression ratio between primary and backup precision (only relevant to INT4).
@@ -445,8 +445,8 @@ def export_models(
        input_shapes (Optional[Dict], optional, Defaults to None):
            If specified, allows to use specific shapes for the example input provided to the exporter.
        compression_option (`Optional[str]`, defaults to `None`):
-            The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-            `i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+            The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+            `int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
        compression_ratio (`Optional[int]`, defaults to `None`):
            Compression ratio between primary and backup precision (only relevant to INT4).
        model_kwargs (Optional[Dict[str, Any]], optional):
diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py
index 9d95c2858e..3471c6f954 100644
--- a/optimum/intel/openvino/modeling_base_seq2seq.py
+++ b/optimum/intel/openvino/modeling_base_seq2seq.py
@@ -262,7 +262,7 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option="i8" if load_in_8bit else None,
+            compression_option="int8" if load_in_8bit else None,
         )
         config.save_pretrained(save_dir_path)
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 9e9797393e..8147cc74e8 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -239,7 +239,7 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            compression_option="i8" if load_in_8bit else None,
+            compression_option="int8" if load_in_8bit else None,
         )
         config.is_decoder = True
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 86e4f8a934..acdfb4a324 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -51,11 +51,11 @@

 COMPRESSION_OPTIONS = {
-    "i8": {"mode": nncf.CompressWeightsMode.INT8},
-    "i4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128},
-    "i4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
-    "i4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64},
-    "i4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
+    "int8": {"mode": nncf.CompressWeightsMode.INT8},
+    "int4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128},
+    "int4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
+    "int4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64},
+    "int4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
 }

 register_module(ignored_algorithms=[])(Conv1D)
diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py
index 917a1ccdb6..b90490d610 100644
--- a/tests/openvino/test_exporters_cli.py
+++ b/tests/openvino/test_exporters_cli.py
@@ -64,7 +64,7 @@ class OVCLIExportTestCase(unittest.TestCase):

     SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),)

-    SUPPORTED_4BIT_OPTIONS = ["i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"]
+    SUPPORTED_4BIT_OPTIONS = ["int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"]

     TEST_4BIT_CONFIGURATONS = []
     for arch in SUPPORTED_4BIT_ARCHITECTURES:
@@ -102,7 +102,7 @@ def test_exporters_cli(self, task: str, model_type: str):
     def test_exporters_cli_fp16(self, task: str, model_type: str):
         with TemporaryDirectory() as tmpdir:
             subprocess.run(
-                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights f16 {tmpdir}",
+                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format fp16 {tmpdir}",
                 shell=True,
                 check=True,
             )
@@ -113,7 +113,7 @@ def test_exporters_cli_fp16(self, task: str, model_type: str):
     def test_exporters_cli_int8(self, task: str, model_type: str):
         with TemporaryDirectory() as tmpdir:
             subprocess.run(
-                f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights i8 {tmpdir}",
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format int8 {tmpdir}", shell=True, check=True, ) @@ -139,7 +139,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str): def test_exporters_cli_int4(self, task: str, model_type: str, option: str): with TemporaryDirectory() as tmpdir: subprocess.run( - f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights {option} {tmpdir}", + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", shell=True, check=True, ) From d074c198a1fbcf987761881ad270b1fba2498de4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 18 Dec 2023 14:35:36 +0400 Subject: [PATCH 13/14] Fixed test --- tests/openvino/test_quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index a1bd848fe5..c3378c08e6 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -150,7 +150,7 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 45, 44), ) - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 82, 323),) + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 82, 295),) SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = ( (OVModelForCausalLM, "gpt2"), @@ -227,7 +227,7 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i quantizer.quantize( save_directory=tmp_dir, weights_only=True, - quantization_config=OVConfig(compression={"type": "i4_sym_g128", "ratio": 0.8}), + quantization_config=OVConfig(compression={"type": "int4_sym_g128", "ratio": 0.8}), ) model = model_cls.from_pretrained(tmp_dir) From 1871329882d74802af45028e883ac3bc1879099a Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:49:38 +0100 Subject: [PATCH 14/14] Update tests/openvino/utils_tests.py --- tests/openvino/utils_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index efbe18f4e1..6cfeb29bb4 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -116,7 +116,7 @@ } -_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (82, 323)} +_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (82, 295)} def get_num_quantized_nodes(ov_model):