INT4 compression support #469

Merged · 18 commits · Dec 18, 2023
Changes from 12 commits
21 changes: 17 additions & 4 deletions optimum/commands/export/openvino.py
@@ -68,8 +68,21 @@ def parse_args_openvino(parser: "ArgumentParser"):
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument("--fp16", action="store_true", help="Compress weights to fp16"),
optional_group.add_argument("--int8", action="store_true", help="Compress weights to int8"),
optional_group.add_argument(
"--weight-format",
type=str,
choices=["f32", "f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"],
Collaborator:

Would be in favor of using fp32, fp16, int8 to keep the same format as for transformers and optimum

Collaborator:

Also, we might want to move the sym/asym choice out of this set of options so that it can also be made available for int8. Not sure it's needed though; the default asym mode might be enough. Let me know what you think.

Collaborator Author:

We will have INT8 symmetric in the new version of NNCF. I am also thinking that we need to reduce the number of available options here and keep only symmetrical because they provide a better accuracy-performance trade-off (varying group size and ratio). @ljaljushkin, please provide your opinion as well.

Contributor:

Symmetric mode has a better correlation between model size and latency than the asymmetric one.
I can't say for sure that varying the group size and ratio for symmetric mode always gives a decent accuracy-performance trade-off; there are some models where symmetric mode doesn't achieve it.
[image attached]

Collaborator Author:

> Would be in favor of using fp32, fp16, int8 to keep the same format as for transformers and optimum

I followed the notation of OpenVINO types but I can change.

Collaborator (@helena-intel, Dec 15, 2023):

> I am also thinking that we need to reduce the number of available options here

I would much prefer to be able to use all weight compression options available in NNCF in Optimum. In my experience there are always specific cases where they are useful, and it's not good to have to completely switch frameworks/APIs when you want to use them. Also agreed that we should not overwhelm users - but in my opinion we're not there yet - and it also introduces confusion if there are differences between what's available in NNCF and what's available in optimum.

Collaborator:

> Would be in favor of using fp32, fp16, int8 to keep the same format as for transformers and optimum

> I followed the notation of OpenVINO types but I can change.

I think that it would be easier to keep consistency with the other Optimum subpackages.

No strong opinion concerning the symmetric/asymmetric mode, I'm fine with both options.

Collaborator Author:

> I am also thinking that we need to reduce the number of available options here

> I would much prefer to be able to use all weight compression options available in NNCF in Optimum. In my experience there are always specific cases where they are useful, and it's not good to have to completely switch frameworks/APIs when you want to use them. Also agreed that we should not overwhelm users - but in my opinion we're not there yet - and it also introduces confusion if there are differences between what's available in NNCF and what's available in optimum.

Thanks, Helena! Understood your concerns, but we have experimental schemes in NNCF that are not yet performant in OpenVINO, so I am not going to expose them at this point.

Collaborator Author:

> Would be in favor of using fp32, fp16, int8 to keep the same format as for transformers and optimum

> I followed the notation of OpenVINO types but I can change.

> I think that it would be easier to keep consistency with the other Optimum subpackages.

> No strong opinion concerning the symmetric/asymmetric mode, I'm fine with both options.

Barring strong objections, I can align the names of the precisions with the rest of the HF ecosystem.

Collaborator Author:

updated names

default=None,
help=(
"The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights."
),
)
optional_group.add_argument(
"--ratio",
type=float,
default=0.8,
help="Compression ratio between primary and backup precision (only applicable to INT4 type).",
)


class OVExportCommand(BaseOptimumCLICommand):
@@ -104,7 +117,7 @@ def run(self):
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
fp16=self.args.fp16,
int8=self.args.int8,
compression_option=self.args.weights_format,
compression_ratio=self.args.ratio
# **input_shapes,
)
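To make the intended usage concrete, here is a minimal sketch (not part of the diff) of driving the new options from Python. The `main_export` import path and the commit-12 option names (`i4_sym_g128`, `--weight-format`, `--ratio`) are assumptions based on this revision and may differ after the renaming discussed above.

```python
# Illustrative only: a minimal sketch of the Python entry point this PR extends,
# assuming main_export is re-exported from optimum.exporters.openvino and that
# the commit-12 option names (i4_sym_g128, ...) are still accepted.
from optimum.exporters.openvino import main_export

main_export(
    "gpt2",                            # model id on the Hub or a local path
    output="gpt2_ov_int4/",            # output directory for the OpenVINO IR
    compression_option="i4_sym_g128",  # INT4 symmetric weights, group size 128
    compression_ratio=0.8,             # roughly 80% of weights in INT4, the rest in the backup precision (per the --ratio help above)
)
```

At this revision the CLI equivalent would be something like `optimum-cli export openvino --model gpt2 --weight-format i4_sym_g128 --ratio 0.8 gpt2_ov_int4/` (exact spelling subject to the renaming above).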
29 changes: 18 additions & 11 deletions optimum/exporters/openvino/__main__.py
@@ -51,7 +51,6 @@ def main_export(
output: Union[str, Path],
task: str = "auto",
device: str = "cpu",
fp16: Optional[bool] = False,
framework: Optional[str] = None,
cache_dir: Optional[str] = None,
trust_remote_code: bool = False,
@@ -64,7 +63,8 @@
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
int8: Optional[bool] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
**kwargs_shapes,
):
"""
@@ -85,8 +85,6 @@
use `xxx-with-past` to export the model using past key values in the decoder.
device (`str`, defaults to `"cpu"`):
The device to use to do the export. Defaults to "cpu".
fp16 (`Optional[bool]`, defaults to `"False"`):
Use half precision during the export. PyTorch-only, requires `device="cuda"`.
framework (`Optional[str]`, defaults to `None`):
The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect
the framework for the checkpoint.
@@ -121,6 +119,11 @@
fn_get_submodels (`Optional[Callable]`, defaults to `None`):
Experimental usage: Override the default submodels that are used at the export. This is
especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

@@ -131,9 +134,14 @@
>>> main_export("gpt2", output="gpt2_onnx/")
```
"""
if int8 and not is_nncf_available():
if (
compression_option is not None
and compression_option != "f16"
and compression_option != "f32"
and not is_nncf_available()
):
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
f"Compression of the weights to {compression_option} requires nncf, please install it with `pip install nncf`"
)

model_kwargs = model_kwargs or {}
@@ -285,12 +293,11 @@ class StoreAttr(object):
legacy=False,
)

if int8 is None:
int8 = False
if compression_option is None:
num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
int8 = True
compression_option = "i8"
logger.info("The model weights will be quantized to int8.")
else:
logger.warning(
@@ -364,8 +371,8 @@ class StoreAttr(object):
output_names=files_subpaths,
input_shapes=input_shapes,
device=device,
fp16=fp16,
int8=int8,
compression_option=compression_option,
compression_ratio=compression_ratio,
model_kwargs=model_kwargs,
)

96 changes: 74 additions & 22 deletions optimum/exporters/openvino/convert.py

@@ -54,16 +54,41 @@
from transformers.modeling_tf_utils import TFPreTrainedModel


def _save_model(model, path: str, compress_to_fp16=False, load_in_8bit=False):
if load_in_8bit:
def _save_model(model, path: str, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None):
if compression_option is not None and compression_option != "f16" and compression_option != "f32":
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

import nncf

model = nncf.compress_weights(model)
COMPRESSION_OPTIONS = {
"i8": {"mode": nncf.CompressWeightsMode.INT8},
"i4_sym_g128": {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"group_size": 128,
"ratio": compression_ratio,
},
"i4_asym_g128": {
"mode": nncf.CompressWeightsMode.INT4_ASYM,
"group_size": 128,
"ratio": compression_ratio,
},
"i4_sym_g64": {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"group_size": 64,
"ratio": compression_ratio,
},
"i4_asym_g64": {
"mode": nncf.CompressWeightsMode.INT4_ASYM,
"group_size": 64,
"ratio": compression_ratio,
},
}
model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option])

compress_to_fp16 = compression_option == "f16"
save_model(model, path, compress_to_fp16)
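For readers unfamiliar with NNCF, the dispatch above boils down to a direct `nncf.compress_weights` call on the OpenVINO model. Below is a standalone sketch, not part of the PR, under assumed import paths and file names (openvino >= 2023.1 style API).

```python
# Standalone sketch of the i4_sym_g128 branch above; assumes nncf and openvino
# are installed and an existing IR file "model.xml" (hypothetical path).
import nncf
import openvino as ov

ov_model = ov.Core().read_model("model.xml")
ov_model = nncf.compress_weights(
    ov_model,
    mode=nncf.CompressWeightsMode.INT4_SYM,  # symmetric INT4 quantization of weights
    group_size=128,                          # groups of 128 weights share quantization parameters
    ratio=0.8,                               # fraction compressed to INT4; the remainder stays in the backup precision
)
ov.save_model(ov_model, "model_int4.xml", compress_to_fp16=False)
```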


@@ -75,8 +100,8 @@ def export(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an OpenVINO Intermediate Representation.
@@ -93,6 +118,11 @@
device (`str`, *optional*, defaults to `cpu`):
The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the exporter.

@@ -117,9 +147,9 @@
output,
device=device,
input_shapes=input_shapes,
compression_option=compression_option,
compression_ratio=compression_ratio,
model_kwargs=model_kwargs,
fp16=fp16,
int8=int8,
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
@@ -143,6 +173,8 @@ def export_tensorflow(
config: OnnxConfig,
opset: int,
output: Path,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
):
"""
Export the TensorFlow model to OpenVINO format.
@@ -161,7 +193,9 @@
onnx_path = Path(output).with_suffix(".onnx")
input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
ov_model = convert_model(str(onnx_path))
_save_model(ov_model, output.parent / output, compress_to_fp16=False, load_in_8bit=False)
_save_model(
ov_model, output.parent / output, compression_option=compression_option, compression_ratio=compression_ratio
)
return input_names, output_names, True


@@ -173,8 +207,8 @@ def export_pytorch_via_onnx(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
):
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
@@ -194,7 +228,12 @@
input_shapes (`optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the exporter.
model_kwargs (optional[Dict[str, Any]], defaults to `None`):
Additional kwargs for model export
Additional kwargs for model export.
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).

Returns:
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -216,8 +255,8 @@ def export_pytorch_via_onnx(
_save_model(
ov_model,
output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
compress_to_fp16=fp16,
load_in_8bit=int8,
compression_option=compression_option,
compression_ratio=compression_ratio,
)
return input_names, output_names, True

@@ -230,8 +269,8 @@ def export_pytorch(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an OpenVINO Intermediate Representation.
@@ -342,7 +381,15 @@ def ts_patched_forward(*args, **kwargs):
if patch_model_forward:
model.forward = orig_forward
return export_pytorch_via_onnx(
model, config, opset, output, device, input_shapes, model_kwargs, fp16=fp16, int8=int8
model,
config,
opset,
output,
device,
input_shapes,
model_kwargs,
compression_option=compression_option,
compression_ratio=compression_ratio,
)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
@@ -364,7 +411,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()
_save_model(ov_model, output, compress_to_fp16=fp16, load_in_8bit=int8)
_save_model(ov_model, output, compression_option=compression_option, compression_ratio=compression_ratio)
clear_class_registry()
del model
gc.collect()
@@ -381,8 +428,8 @@ def export_models(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
fp16: bool = False,
int8: bool = False,
compression_option: Optional[str] = None,
compression_ratio: Optional[int] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Export the models to OpenVINO IR format
Expand All @@ -397,8 +444,13 @@ def export_models(
export on CUDA devices.
input_shapes (Optional[Dict], optional, Defaults to None):
If specified, allows to use specific shapes for the example input provided to the exporter.
compression_option (`Optional[str]`, defaults to `None`):
The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[int]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
model_kwargs (Optional[Dict[str, Any]], optional):
Additional kwargs for model export
Additional kwargs for model export.

Raises:
ValueError: if custom names set not equal of number of models
@@ -427,8 +479,8 @@ def export_models(
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
fp16=fp16,
int8=int8,
compression_option=compression_option,
compression_ratio=compression_ratio,
)
)

6 changes: 5 additions & 1 deletion optimum/intel/openvino/configuration.py

@@ -107,7 +107,11 @@ def _enable_standard_onnx_export_option(self):
# save_onnx_model is defaulted to false so that the final model output is
# in OpenVINO IR to realize performance benefit in OpenVINO runtime.
# True value of save_onnx_model will save a model in onnx format.
if isinstance(self.compression, dict) and self.compression["algorithm"] == "quantization":
if (
isinstance(self.compression, dict)
and "algorithm" in self.compression
and self.compression["algorithm"] == "quantization"
):
self.compression["export_to_onnx_standard_ops"] = self.save_onnx_model
elif isinstance(self.compression, list):
for i, algo_config in enumerate(self.compression):
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py

@@ -262,7 +262,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
int8=load_in_8bit,
compression_option="i8" if load_in_8bit else None,
)

config.save_pretrained(save_dir_path)
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py

@@ -239,7 +239,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
int8=load_in_8bit,
compression_option="i8" if load_in_8bit else None,
)

config.is_decoder = True
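From the user-facing side nothing changes: `load_in_8bit=True` keeps working and is now translated into `compression_option="i8"` internally by the two `_from_transformers` diffs above. A hedged sketch follows; the class and kwargs are those commonly used in optimum-intel and are not added by this PR.

```python
# Hedged sketch: load_in_8bit is forwarded to main_export as compression_option="i8"
# by the _from_transformers changes above; names assumed from optimum-intel.
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(
    "gpt2",
    export=True,        # convert the PyTorch checkpoint to OpenVINO IR on the fly
    load_in_8bit=True,  # becomes compression_option="i8" in the export call
)
```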