Fixed names of precisions
AlexKoff88 committed Dec 18, 2023
1 parent 96f0af5 commit c55610b
Showing 7 changed files with 33 additions and 33 deletions.
8 changes: 4 additions & 4 deletions optimum/commands/export/openvino.py
@@ -77,10 +77,10 @@ def parse_args_openvino(parser: "ArgumentParser"):
optional_group.add_argument(
"--weight-format",
type=str,
choices=["f32", "f16", "i8", "i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"],
choices=["fp32", "fp16", "int8", "int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"],
default=None,
help=(
"The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, i4_* - for INT4 compressed weights."
"The weight format of the exporting model, e.g. f32 stands for float32 weights, f16 - for float16 weights, i8 - INT8 weights, int4_* - for INT4 compressed weights."
),
)
optional_group.add_argument(
@@ -121,12 +121,12 @@ def run(self):
logger.warning(
"`--fp16` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
)
self.args.weight_format = "f16"
self.args.weight_format = "fp16"
if self.args.int8:
logger.warning(
"`--int8` option is deprecated and will be removed in a future version. Use `--weight-format` instead."
)
self.args.weight_format = "i8"
self.args.weight_format = "int8"

# TODO : add input shapes
main_export(
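
Taken together, the CLI changes above only rename the accepted values; behavior is unchanged. A minimal sketch of an invocation with the new names, mirroring the updated tests at the bottom of this diff (the model id and output directory are placeholders, not part of this commit):

import subprocess

# Placeholder model id and output dir; valid values after this commit are
# "fp32", "fp16", "int8", and the "int4_*" variants listed in `choices` above.
subprocess.run(
    "optimum-cli export openvino --model gpt2 --weight-format int4_sym_g128 ov_model",
    shell=True,
    check=True,
)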
10 changes: 5 additions & 5 deletions optimum/exporters/openvino/__main__.py
@@ -120,8 +120,8 @@ def main_export(
Experimental usage: Override the default submodels that are used at the export. This is
especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
compression_option (`Optional[str]`, defaults to `None`):
-The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
+The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point, `fp32` - means no compression.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
**kwargs_shapes (`Dict`):
@@ -136,8 +136,8 @@ def main_export(
"""
if (
compression_option is not None
and compression_option != "f16"
and compression_option != "f32"
and compression_option != "fp16"
and compression_option != "fp32"
and not is_nncf_available()
):
raise ImportError(
@@ -297,7 +297,7 @@ class StoreAttr(object):
num_parameters = model.num_parameters() if not is_stable_diffusion else model.unet.num_parameters()
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
compression_option = "i8"
compression_option = "int8"
logger.info("The model weights will be quantized to int8.")
else:
logger.warning(
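
The same renaming applies to the Python entry point. A hedged sketch of a `main_export` call using the new spellings (the model id, output path, and 0.8 ratio are illustrative, not from this commit, and the keyword layout is assumed from the docstring above):

from optimum.exporters.openvino import main_export

main_export(
    "gpt2",  # placeholder model id
    output="ov_model",  # placeholder output directory
    compression_option="int4_sym_g128",
    compression_ratio=0.8,  # share of weights kept in the primary INT4 precision; the rest stays in backup precision
)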
26 changes: 13 additions & 13 deletions optimum/exporters/openvino/convert.py
@@ -55,7 +55,7 @@


def _save_model(model, path: str, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None):
-if compression_option is not None and compression_option != "f16" and compression_option != "f32":
+if compression_option is not None and compression_option != "fp16" and compression_option != "fp32":
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
@@ -64,31 +64,31 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp
import nncf

COMPRESSION_OPTIONS = {
"i8": {"mode": nncf.CompressWeightsMode.INT8},
"i4_sym_g128": {
"int8": {"mode": nncf.CompressWeightsMode.INT8},
"int4_sym_g128": {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"group_size": 128,
"ratio": compression_ratio,
},
"i4_asym_g128": {
"int4_asym_g128": {
"mode": nncf.CompressWeightsMode.INT4_ASYM,
"group_size": 128,
"ratio": compression_ratio,
},
"i4_sym_g64": {
"int4_sym_g64": {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"group_size": 64,
"ratio": compression_ratio,
},
"i4_asym_g64": {
"int4_asym_g64": {
"mode": nncf.CompressWeightsMode.INT4_ASYM,
"group_size": 64,
"ratio": compression_ratio,
},
}
model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option])

-compress_to_fp16 = compression_option == "f16"
+compress_to_fp16 = compression_option == "fp16"
save_model(model, path, compress_to_fp16)


@@ -119,8 +119,8 @@ def export(
The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
compression_option (`Optional[str]`, defaults to `None`):
-The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
input_shapes (`Optional[Dict]`, defaults to `None`):
@@ -230,8 +230,8 @@ def export_pytorch_via_onnx(
model_kwargs (optional[Dict[str, Any]], defaults to `None`):
Additional kwargs for model export.
compression_option (`Optional[str]`, defaults to `None`):
-The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
@@ -445,8 +445,8 @@ def export_models(
input_shapes (Optional[Dict], optional, Defaults to None):
If specified, allows to use specific shapes for the example input provided to the exporter.
compression_option (`Optional[str]`, defaults to `None`):
-The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `i4_sym_g128` - INT4 symmetric weights w/ group size 128, `i4_asym_g128` - as previous but asymmetric w/ zero-point,
-`i4_sym_g64` - INT4 symmetric weights w/ group size 64, "i4_asym_g64" - as previous but asymmetric w/ zero-point.
+The weight compression option, e.g. `fp16` stands for float16 weights, `int8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point,
+`int4_sym_g64` - INT4 symmetric weights w/ group size 64, `int4_asym_g64` - as previous but asymmetric w/ zero-point.
compression_ratio (`Optional[int]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
model_kwargs (Optional[Dict[str, Any]], optional):
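
The COMPRESSION_OPTIONS table above simply selects keyword arguments for `nncf.compress_weights`. A condensed sketch of what `_save_model` now does for one renamed 4-bit option, assuming nncf is installed (the helper name is hypothetical):

import nncf
from openvino.runtime import save_model

def save_int4_sym_g128(model, path, compression_ratio):
    # Equivalent of compression_option="int4_sym_g128" in the table above.
    model = nncf.compress_weights(
        model,
        mode=nncf.CompressWeightsMode.INT4_SYM,
        group_size=128,
        ratio=compression_ratio,
    )
    # compress_to_fp16 is False for every option other than "fp16".
    save_model(model, path, False)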
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py
@@ -262,7 +262,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
compression_option="i8" if load_in_8bit else None,
compression_option="int8" if load_in_8bit else None,
)

config.save_pretrained(save_dir_path)
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -239,7 +239,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
compression_option="i8" if load_in_8bit else None,
compression_option="int8" if load_in_8bit else None,
)

config.is_decoder = True
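
The two modeling changes route the high-level `load_in_8bit` flag to the renamed option. A hedged sketch of the user-facing path (the model id is a placeholder):

from optimum.intel import OVModelForCausalLM

# export=True converts the checkpoint on the fly; load_in_8bit=True now
# forwards compression_option="int8", per the _from_transformers changes above.
model = OVModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model id
    export=True,
    load_in_8bit=True,
)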
10 changes: 5 additions & 5 deletions optimum/intel/openvino/quantization.py
@@ -51,11 +51,11 @@


COMPRESSION_OPTIONS = {
"i8": {"mode": nncf.CompressWeightsMode.INT8},
"i4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128},
"i4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
"i4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64},
"i4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
"int8": {"mode": nncf.CompressWeightsMode.INT8},
"int4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128},
"int4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
"int4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64},
"int4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
}

register_module(ignored_algorithms=[])(Conv1D)
8 changes: 4 additions & 4 deletions tests/openvino/test_exporters_cli.py
@@ -64,7 +64,7 @@ class OVCLIExportTestCase(unittest.TestCase):

SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),)

-SUPPORTED_4BIT_OPTIONS = ["i4_sym_g128", "i4_asym_g128", "i4_sym_g64", "i4_asym_g64"]
+SUPPORTED_4BIT_OPTIONS = ["int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"]

TEST_4BIT_CONFIGURATONS = []
for arch in SUPPORTED_4BIT_ARCHITECTURES:
@@ -102,7 +102,7 @@ def test_exporters_cli(self, task: str, model_type: str):
def test_exporters_cli_fp16(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights f16 {tmpdir}",
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format fp16 {tmpdir}",
shell=True,
check=True,
)
@@ -113,7 +113,7 @@ def test_exporters_cli_fp16(self, task: str, model_type: str):
def test_exporters_cli_int8(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights i8 {tmpdir}",
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format int8 {tmpdir}",
shell=True,
check=True,
)
@@ -139,7 +139,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
def test_exporters_cli_int4(self, task: str, model_type: str, option: str):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --compress-weights {option} {tmpdir}",
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
shell=True,
check=True,
)
