From de9b1180fd74af417a55175b9edc6fdf725fba2c Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 21 Feb 2024 12:21:13 +0100 Subject: [PATCH] fix ov config for fp32 models --- optimum/commands/export/openvino.py | 16 +++++++++------- optimum/exporters/openvino/convert.py | 3 ++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index a0246be7d1..255e2a7e13 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -157,11 +157,13 @@ def run(self): ) self.args.weight_format = "int8" + weight_format = self.args.weight_format or "fp32" + ov_config = None - if self.args.weight_format in {"fp16", "fp32"}: - ov_config = OVConfig(dtype=self.args.weight_format) + if weight_format in {"fp16", "fp32"}: + ov_config = OVConfig(dtype=weight_format) else: - is_int8 = self.args.weight_format == "int8" + is_int8 = weight_format == "int8" # For int4 quantization if not parameter is provided, then use the default config if exist if ( @@ -180,12 +182,12 @@ def run(self): "group_size": -1 if is_int8 else self.args.group_size, } - if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}: + if weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}: logger.warning( - f"--weight-format {self.args.weight_format} is deprecated, possible choices are fp32, fp16, int8, int4" + f"--weight-format {weight_format} is deprecated, possible choices are fp32, fp16, int8, int4" ) - quantization_config["sym"] = "asym" not in self.args.weight_format - quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64 + quantization_config["sym"] = "asym" not in weight_format + quantization_config["group_size"] = 128 if "128" in weight_format else 64 ov_config = OVConfig(quantization_config=quantization_config) # TODO : add input shapes diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 2b12f49be7..9dba2ac324 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -77,8 +77,9 @@ def _save_model(model, path: str, ov_config: Optional["OVConfig"] = None): compress_to_fp16 = False + if ov_config is not None: - if ov_config.quantization_config is not None: + if ov_config.quantization_config: if not is_nncf_available(): raise ImportError( "Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"