[OV] Move data-driven quantization after model export for text-generation models (#721)

* Add quantization with dataset after model export for text-generation models

* Tweak AWQ CLI interface

* Additional checks

* Fix

* Trigger Build

* Add AWQ description

* Add trust remote code argument

* Black

* Add note about possibility of skipping AWQ

* Removed saving to temporary directory; added core property handling for OVModelForCausalLM

* Revert "Removed saving to temporary directory; added core property handling for OVModelForCausalLM"

This reverts commit bcc4665.

* Add saving intermediate weights in fp16; add removal of intermediate model if compression fails

* Trigger checks

* Trigger checks

* Trigger checks

* Fix test

* Refactor applying quantization with dataset

* Bring back quantization_config parameter

* Trigger checks

* Apply comment

* Save tokenizer

* Export CausalLM tokenizer

* Remove unnecessary if

* Remove extra variable

* ruff

* Ruff 2

* Introduce a separate function for tokenizer conversion

* Black
nikita-savelyevv authored Jun 6, 2024
1 parent f06f504 commit 6888c0a
Showing 5 changed files with 134 additions and 74 deletions.
87 changes: 63 additions & 24 deletions optimum/commands/export/openvino.py
@@ -19,9 +19,11 @@
from typing import TYPE_CHECKING, Optional

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from transformers.utils.quantization_config import QuantizationMethod

from ...exporters import TasksManager
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from ..base import BaseOptimumCLICommand, CommandInfo


@@ -128,6 +130,33 @@ def parse_args_openvino(parser: "ArgumentParser"):
"compression is applied, they are compressed to INT8."
),
)
optional_group.add_argument(
"--awq",
action="store_true",
default=None,
help=(
"Whether to apply the AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but requires "
"additional time for tuning weights on a calibration dataset. To run AWQ, please also provide a dataset "
"argument. Note: the model may contain no patterns that AWQ can be applied to, in which case AWQ "
"will be skipped."
),
)
optional_group.add_argument(
"--sensitivity-metric",
type=str,
default=None,
help=(
"The sensitivity metric for assigning quantization precision to layers. Can be one of the following: "
"['weight_quantization_error', 'hessian_input_activation', 'mean_activation_variance', "
"'max_activation_variance', 'mean_activation_magnitude']."
),
)
optional_group.add_argument(
"--num-samples",
type=int,
default=None,
help="The maximum number of samples to take from the dataset for quantization.",
)
optional_group.add_argument(
"--disable-stateful",
action="store_true",
@@ -180,7 +209,7 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_openvino(parser)

def run(self):
from ...exporters.openvino.__main__ import main_export
from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers
from ...intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS, OVConfig

if self.args.fp16:
@@ -208,6 +237,10 @@ def run(self):
and self.args.group_size is None
and self.args.sym is None
and self.args.all_layers is None
and self.args.dataset is None
and self.args.num_samples is None
and self.args.awq is None
and self.args.sensitivity_metric is None
and self.args.model in _DEFAULT_4BIT_CONFIGS
):
quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -218,6 +251,10 @@
"sym": self.args.sym or False,
"group_size": -1 if is_int8 else self.args.group_size,
"all_layers": None if is_int8 else self.args.all_layers,
"dataset": self.args.dataset,
"num_samples": self.args.num_samples,
"quant_method": QuantizationMethod.AWQ if self.args.awq else None,
"sensitivity_metric": self.args.sensitivity_metric,
}

if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
@@ -226,7 +263,6 @@ def run(self):
)
quantization_config["sym"] = "asym" not in self.args.weight_format
quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
quantization_config["dataset"] = self.args.dataset
ov_config = OVConfig(quantization_config=quantization_config)

library_name = TasksManager.infer_library_from_model(self.args.model, library_name=self.args.library)
@@ -240,12 +276,11 @@ def run(self):
if self.args.convert_tokenizer:
logger.warning("`--convert-tokenizer` option is deprecated. Tokenizer will be converted by default.")

if (
library_name == "diffusers"
and ov_config
and ov_config.quantization_config
and ov_config.quantization_config.dataset is not None
):
quantization_config = ov_config.quantization_config if ov_config else None
quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
task = infer_task(self.args.task, self.args.model)

if library_name == "diffusers" and quantize_with_dataset:
if not is_diffusers_available():
raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))

@@ -270,25 +305,29 @@ def run(self):
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

model = model_cls.from_pretrained(
self.args.model, export=True, quantization_config=ov_config.quantization_config
model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config)
model.save_pretrained(self.args.output)
if not self.args.disable_convert_tokenizer:
maybe_convert_tokenizers(library_name, self.args.output, model)
elif task.startswith("text-generation") and quantize_with_dataset:
from optimum.intel import OVModelForCausalLM

# To quantize a text-generation model with a dataset, an instantiated OVModelForCausalLM is required
model = OVModelForCausalLM.from_pretrained(
self.args.model,
export=True,
quantization_config=quantization_config,
stateful=not self.args.disable_stateful,
trust_remote_code=self.args.trust_remote_code,
)
model.save_pretrained(self.args.output)

if self.args.disable_convert_tokenizer:
return

# avoid import when using other exporters (IPEX, INC)
from ...exporters.openvino.convert import export_tokenizer

output = Path(self.args.output)
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
export_tokenizer(tokenizer, output / "tokenizer")

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
export_tokenizer(tokenizer_2, output / "tokenizer_2")
maybe_save_preprocessors(self.args.model, self.args.output, trust_remote_code=self.args.trust_remote_code)
if not self.args.disable_convert_tokenizer:
preprocessors = maybe_load_preprocessors(
self.args.model, trust_remote_code=self.args.trust_remote_code
)
maybe_convert_tokenizers(library_name, self.args.output, preprocessors=preprocessors)
else:
# TODO : add input shapes
main_export(
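For orientation, here is a minimal sketch of what the rewritten text-generation branch of run() amounts to when a calibration dataset is supplied: the model is exported first, then compressed data-aware through an instantiated OVModelForCausalLM. The model ID, output directory, and exact config values below are illustrative, not taken from the commit.

# Roughly equivalent CLI call (illustrative model ID and output directory):
#   optimum-cli export openvino --model meta-llama/Llama-2-7b-hf --weight-format int4 \
#       --awq --dataset wikitext2 --num-samples 128 \
#       --sensitivity-metric max_activation_variance ov_llama_int4_awq
from transformers.utils.quantization_config import QuantizationMethod

from optimum.intel import OVModelForCausalLM

# Mirrors the dict the CLI builds from its quantization arguments.
quantization_config = {
    "bits": 4,
    "ratio": 1.0,  # illustrative
    "sym": False,
    "group_size": 128,
    "dataset": "wikitext2",
    "num_samples": 128,
    "quant_method": QuantizationMethod.AWQ,  # set only when --awq is passed
    "sensitivity_metric": "max_activation_variance",
}

# Data-driven compression now happens here, after export, on the loaded model.
model = OVModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model ID
    export=True,
    quantization_config=quantization_config,
    stateful=True,  # the CLI passes `not args.disable_stateful`
    trust_remote_code=False,
)
model.save_pretrained("ov_llama_int4_awq")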
86 changes: 50 additions & 36 deletions optimum/exporters/openvino/__main__.py
@@ -44,6 +44,22 @@
logger = logging.getLogger(__name__)


def infer_task(task, model_name_or_path):
task = TasksManager.map_from_synonym(task)
if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
except RequestsConnectionError as e:
raise RequestsConnectionError(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
return task


def main_export(
model_name_or_path: str,
output: Union[str, Path],
@@ -174,7 +190,7 @@ def main_export(
ov_config = OVConfig(quantization_config=q_config)

original_task = task
task = TasksManager.map_from_synonym(task)
task = infer_task(task, model_name_or_path)
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
library_name_is_not_provided = library_name is None
library_name = TasksManager.infer_library_from_model(
@@ -188,18 +204,6 @@
)
library_name = "transformers"

if task == "auto":
try:
task = TasksManager.infer_task_from_model(model_name_or_path)
except KeyError as e:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)
except RequestsConnectionError as e:
raise RequestsConnectionError(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

do_gptq_patching = False
custom_architecture = False
loading_kwargs = {}
@@ -360,36 +364,46 @@ class StoreAttr(object):
**kwargs_shapes,
)

# hide openvino import when using other exporters
from optimum.exporters.openvino.convert import export_tokenizer
if convert_tokenizer:
maybe_convert_tokenizers(library_name, output, model, preprocessors)

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model

if convert_tokenizer and is_openvino_tokenizers_available():
if library_name != "diffusers":
tokenizer = next(
(preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)),
None,
)

if tokenizer is not None:
def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None):
"""
Tries to convert tokenizers to OV format and export them to disk.
Arguments:
library_name (`str`):
The library name.
output (`Path`):
Path to save converted tokenizers to.
model (`PreTrainedModel`, *optional*, defaults to None):
Model instance.
preprocessors (`Iterable`, *optional*, defaults to None):
Iterable possibly containing tokenizers to be converted.
"""
from optimum.exporters.openvino.convert import export_tokenizer

if is_openvino_tokenizers_available():
if library_name != "diffusers" and preprocessors:
tokenizer = next(filter(lambda it: isinstance(it, PreTrainedTokenizerBase), preprocessors), None)
if tokenizer:
try:
export_tokenizer(tokenizer, output)
except Exception as exception:
logger.warning(
"Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
f"models won't be generated. Exception: {exception}"
)
else:
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
export_tokenizer(tokenizer, output / "tokenizer")

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
export_tokenizer(tokenizer_2, output / "tokenizer_2")
elif convert_tokenizer and not is_openvino_tokenizers_available():
elif model:
for tokenizer_name in ("tokenizer", "tokenizer_2"):
tokenizer = getattr(model, tokenizer_name, None)
if tokenizer:
export_tokenizer(tokenizer, output / tokenizer_name)
else:
logger.warning("Tokenizer won't be converted.")

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model
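The two helpers factored out above can also be used on their own. A minimal sketch, assuming the model is hosted on the Hugging Face Hub and openvino-tokenizers is installed (model ID and output path are illustrative):

from pathlib import Path

from optimum.exporters.openvino.__main__ import infer_task, maybe_convert_tokenizers
from optimum.utils.save_utils import maybe_load_preprocessors

model_id = "gpt2"        # illustrative
output = Path("ov_out")  # illustrative

# Resolve "auto" to a concrete task, with a descriptive error if inference fails.
task = infer_task("auto", model_id)

# Convert any tokenizer found among the preprocessors (or attached to a diffusers
# pipeline passed via `model=`) into OpenVINO tokenizer/detokenizer models;
# if openvino-tokenizers is unavailable, the function only logs a warning.
preprocessors = maybe_load_preprocessors(model_id)
maybe_convert_tokenizers("transformers", output, preprocessors=preprocessors)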
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -135,6 +135,7 @@ def fix_op_names_duplicates(model: openvino.runtime.Model):
if file_name.suffix == ".onnx":
model = fix_op_names_duplicates(model) # should be called during model conversion to IR

# TODO: remove this way of applying quantization; instead apply it after an instance of OVModel* is loaded
if quantization_config:
if not is_nncf_available():
raise ImportError(
22 changes: 9 additions & 13 deletions optimum/intel/openvino/modeling_decoder.py
@@ -752,17 +752,7 @@ def _from_pretrained(
local_files_only=local_files_only,
)

if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)

quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

load_in_4bit = quantization_config.bits == 4 if quantization_config else False

model = cls.load_model(
model_cache_path,
quantization_config=None if load_in_4bit else quantization_config,
)
model = cls.load_model(model_cache_path)

model_type = config.model_type.replace("_", "-")
if model_type == "bloom":
@@ -772,7 +762,12 @@
else:
init_cls = cls

enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

enable_compilation = kwargs.pop("compile", True) and not quantization_config

try:
generation_config = GenerationConfig.from_pretrained(
model_id,
@@ -785,6 +780,7 @@
kwargs["generation_config"] = generation_config
except Exception:
pass

causal_model = init_cls(
model=model,
config=config,
@@ -794,7 +790,7 @@
**kwargs,
)

if load_in_4bit:
if quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights requires nncf, please install it with `pip install nncf`"
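The net effect of this refactoring is that weight compression of an OVModelForCausalLM is applied to the already-instantiated model rather than inside load_model, with compilation deferred while a quantization config is pending. A minimal sketch of the resulting user-facing behaviour, assuming an already exported OpenVINO model directory (paths are illustrative):

from optimum.intel import OVModelForCausalLM

# Passing {"bits": 4} resolves to a model-specific default from _DEFAULT_4BIT_CONFIGS
# when one exists; compression is applied after the IR is loaded, and compilation is
# skipped until a pending quantization config has been applied.
model = OVModelForCausalLM.from_pretrained(
    "ov_llama_fp16",  # illustrative path to an exported OpenVINO model
    quantization_config={"bits": 4},
)
model.save_pretrained("ov_llama_int4")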
12 changes: 11 additions & 1 deletion tests/openvino/test_exporters_cli.py
@@ -89,6 +89,14 @@ class OVCLIExportTestCase(unittest.TestCase):
("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86),
("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86),
("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32),
(
"text-generation-with-past",
"llama_awq",
"int4 --ratio 1.0 --sym --group-size 16 --awq --dataset wikitext2 --num-samples 100 "
"--sensitivity-metric max_activation_variance",
4,
28,
),
]

def _openvino_export(
@@ -197,17 +205,19 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in
@parameterized.expand(TEST_4BIT_CONFIGURATONS)
def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int):
with TemporaryDirectory() as tmpdir:
subprocess.run(
result = subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}",
shell=True,
check=True,
capture_output=True,
)
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

_, num_int8, num_int4 = get_num_quantized_nodes(model)
self.assertEqual(expected_int8, num_int8)
self.assertEqual(expected_int4, num_int4)
self.assertTrue("--awq" not in option or b"Applying AWQ" in result.stdout)

def test_exporters_cli_help(self):
subprocess.run(
