Skip to content

Commit

Permalink
fix ov config for fp32 models
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 21, 2024
1 parent 329ef26 commit de9b118
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 9 additions & 7 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,13 @@ def run(self):
)
self.args.weight_format = "int8"

weight_format = self.args.weight_format or "fp32"

ov_config = None
if self.args.weight_format in {"fp16", "fp32"}:
ov_config = OVConfig(dtype=self.args.weight_format)
if weight_format in {"fp16", "fp32"}:
ov_config = OVConfig(dtype=weight_format)
else:
is_int8 = self.args.weight_format == "int8"
is_int8 = weight_format == "int8"

# For int4 quantization, if no parameter is provided, then use the default config if it exists
if (
Expand All @@ -180,12 +182,12 @@ def run(self):
"group_size": -1 if is_int8 else self.args.group_size,
}

if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
if weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
logger.warning(
f"--weight-format {self.args.weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
f"--weight-format {weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
)
quantization_config["sym"] = "asym" not in self.args.weight_format
quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
quantization_config["sym"] = "asym" not in weight_format
quantization_config["group_size"] = 128 if "128" in weight_format else 64
ov_config = OVConfig(quantization_config=quantization_config)

# TODO : add input shapes
Expand Down
3 changes: 2 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@

def _save_model(model, path: str, ov_config: Optional["OVConfig"] = None):
compress_to_fp16 = False

if ov_config is not None:
if ov_config.quantization_config is not None:
if ov_config.quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
Expand Down

0 comments on commit de9b118

Please sign in to comment.