diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml
index 3a15214f99..16c01e7298 100644
--- a/.github/workflows/test_inc.yml
+++ b/.github/workflows/test_inc.yml
@@ -30,8 +30,13 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          pip install cmake
+          pip install py-cpuinfo
+          pip install torch==2.1.0 torchaudio==2.1.0 torchvision==0.16 --extra-index-url https://download.pytorch.org/whl/cpu
           pip install .[neural-compressor,diffusers,tests]
-          pip install intel-extension-for-pytorch
+          pip install intel-extension-for-pytorch==2.1.100
+          pip install intel-extension-for-transformers==1.3.2
+          pip install peft
       - name: Test with Pytest
         run: |
           pytest tests/neural_compressor/
diff --git a/examples/neural_compressor/language-modeling/README.md b/examples/neural_compressor/language-modeling/README.md
index b005bb78a4..80d7a25d16 100644
--- a/examples/neural_compressor/language-modeling/README.md
+++ b/examples/neural_compressor/language-modeling/README.md
@@ -97,4 +97,4 @@ respectively `dynamic`, `static`, `weight_only` or `aware_training`.
 
 The flag `--verify_loading` can be passed along to verify that the resulting quantized model can be loaded correctly.
 
-> **_Note:_** `weight_only` quantization_approach requires neural-compressor >= 2.3
+> **_Note:_** The `weight_only` quantization_approach requires `neural-compressor` >= 2.3 and `intel-extension-for-transformers` >= 1.3.
diff --git a/examples/neural_compressor/language-modeling/requirements.txt b/examples/neural_compressor/language-modeling/requirements.txt
index 410b038891..ec38e83d2d 100644
--- a/examples/neural_compressor/language-modeling/requirements.txt
+++ b/examples/neural_compressor/language-modeling/requirements.txt
@@ -3,3 +3,5 @@ torch >= 1.9
 datasets >= 1.8.0
 sentencepiece != 0.1.92
 protobuf
+intel-extension-for-transformers >= 1.3
+peft
diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py
index cbc523b663..ef24616307 100644
--- a/examples/neural_compressor/language-modeling/run_clm.py
+++ b/examples/neural_compressor/language-modeling/run_clm.py
@@ -57,6 +57,14 @@ from transformers.utils.versions import require_version
 
 from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer, INCTrainer
+from optimum.intel.utils.import_utils import (
+    INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR,
+    is_intel_extension_for_transformers_available,
+)
+
+
+if is_intel_extension_for_transformers_available():
+    from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
 
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -143,7 +151,9 @@ class OptimizationArguments:
     )
     quantization_approach: str = field(
         default="dynamic",
-        metadata={"help": "Quantization approach. Supported approach are static, dynamic and aware_training."},
+        metadata={
+            "help": "Quantization approach. Supported approaches are static, dynamic, aware_training and weight_only."
+        },
     )
     smooth_quant: bool = field(
         default=False,
@@ -196,9 +206,13 @@ class OptimizationArguments:
         default=False,
         metadata={"help": "Whether or not to verify the loading of the quantized model."},
     )
-    bits: int = field(
-        default=8,
-        metadata={"help": "Bits for weight only quantization, 1-8 bits."},
+    bits: str = field(
+        default="4",
+        metadata={"help": "Number of bits used for weight-only quantization (1-8 bits)."},
+    )
+    weight_dtype: str = field(
+        default="int4_clip",
+        metadata={"help": "Weight dtype used for weight-only quantization."},
     )
     group_size: int = field(
         default=-1,
@@ -214,10 +228,29 @@ class OptimizationArguments:
     )
     quantization_methodology: str = field(
         default="RTN",
+        metadata={"help": "Quantization methodology for weight-only quantization. Choose from 'RTN' and 'GPTQ'."},
+    )
+    damp_percent: float = field(
+        default=0.01,
         metadata={
-            "help": "Quantization methodology for weight only quantization. Choose from 'RTN', 'AWQ' and 'GPTQ'."
+            "help": "Percentage of the average of the Hessian's diagonal values added to the Hessian's diagonal to improve numerical stability, used for GPTQ quantization."
         },
     )
+    gptq_block_size: int = field(
+        default=128,
+        metadata={"help": "Block size, i.e. the size of the sub weight matrix on which GPTQ is run."},
+    )
+    num_calibration_samples: int = field(
+        default=128, metadata={"help": "Number of examples to use for the GPTQ calibration step."}
+    )
+    use_max_length: bool = field(
+        default=False,
+        metadata={"help": "Whether to pad all calibration sequences to the same length, given by pad_max_length."},
+    )
+    pad_max_length: int = field(
+        default=2048,
+        metadata={"help": "Maximum sequence length of the calibration dataset; this should align with the model configuration."},
+    )
 
 
 @dataclass
@@ -625,26 +658,30 @@ def compute_metrics(eval_preds):
         else:
             recipes = {}
         if optim_args.quantization_approach == "weight_only":
-            op_type_dict = {
-                ".*": {
-                    "weight": {
-                        "bits": optim_args.bits,
-                        "group_size": optim_args.group_size,
-                        "scheme": optim_args.weight_only_scheme,
-                        "algorithm": optim_args.quantization_methodology,
-                    },
-                },
-            }
+            if not is_intel_extension_for_transformers_available():
+                raise ImportError(INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR.format("WeightOnly quantization"))
+            if optim_args.apply_pruning or optim_args.apply_distillation:
+                raise ValueError("Weight-only quantization cannot be combined with pruning or distillation.")
             if optim_args.quantization_methodology == "GPTQ":
-                gptq_args = {
-                    "pad_max_length": block_size,
+                algorithm_args = {
+                    "act_order": False,
+                    "percdamp": optim_args.damp_percent,
+                    "block_size": optim_args.gptq_block_size,
+                    "nsamples": optim_args.num_calibration_samples,
+                    "use_max_length": optim_args.use_max_length,
+                    "pad_max_length": optim_args.pad_max_length,
                 }
-                recipes.update({"gptq_args": gptq_args})
+            quantization_config = WeightOnlyQuantConfig(
+                weight_dtype=optim_args.weight_dtype,
+                group_size=optim_args.group_size,
+                scheme=optim_args.weight_only_scheme,
+                algorithm=optim_args.quantization_methodology,
+                algorithm_args=algorithm_args if optim_args.quantization_methodology == "GPTQ" else None,
+            )
         else:
-            op_type_dict = {}
-        quantization_config = PostTrainingQuantConfig(
-            approach=optim_args.quantization_approach, op_type_dict=op_type_dict, recipes=recipes
-        )
+            quantization_config = PostTrainingQuantConfig(
+                approach=optim_args.quantization_approach, recipes=recipes
+            )
 
     if optim_args.apply_pruning:
         if optim_args.end_step is None:
@@ -732,15 +769,15 @@ def compute_metrics(eval_preds):
         quantizer.quantize(
             quantization_config=quantization_config,
             save_directory=training_args.output_dir,
-            calibration_dataset=train_dataset
-            if optim_args.quantization_approach in ["static", "weight_only"]
-            else None,
-            batch_size=1  # batch_size > 1 for GPTQ is WIP
-            if optim_args.quantization_approach == "weight_only" and optim_args.quantization_methodology == "GPTQ"
-            else training_args.per_device_train_batch_size,
-            weight_only=True if optim_args.quantization_approach == "weight_only" else False,
+            calibration_dataset=(
+                train_dataset if optim_args.quantization_approach in ["static", "weight_only"] else None
+            ),
+            batch_size=(
+                1 if optim_args.quantization_approach == "weight_only" else training_args.per_device_train_batch_size
+            ),
         )
         trainer.model = quantizer._quantized_model
+
     if optim_args.apply_quantization and optim_args.verify_loading:
         loaded_model = INCModelForCausalLM.from_pretrained(training_args.output_dir)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
diff --git a/examples/neural_compressor/text-generation/run_generation.py b/examples/neural_compressor/text-generation/run_generation.py
index e06bba4102..9966a73c10 100755
--- a/examples/neural_compressor/text-generation/run_generation.py
+++ b/examples/neural_compressor/text-generation/run_generation.py
@@ -368,9 +368,7 @@ def calibration_fn(p_model):
 
     args.length = adjust_length_to_model(
         args.length,
-        max_sequence_length=model.config.max_position_embeddings
-        if hasattr(model.config, "max_position_embeddings")
-        else 0,
+        max_sequence_length=getattr(model.config, "max_position_embeddings", 0),
     )
 
     logger.info(args)
diff --git a/optimum/intel/neural_compressor/__init__.py b/optimum/intel/neural_compressor/__init__.py
index a7170120b7..2daecfbc93 100644
--- a/optimum/intel/neural_compressor/__init__.py
+++ b/optimum/intel/neural_compressor/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils.import_utils import is_diffusers_available
+from ..utils.import_utils import is_diffusers_available, is_intel_extension_for_transformers_available
 from .configuration import INCConfig
 from .modeling_base import (
     INCModel,
diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index b74bd305bc..c37542fdef 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -43,7 +43,12 @@ from optimum.intel.generation import BaseModelForCausalLM
 
 from ...modeling_base import OptimizedModel
-from ..utils.import_utils import _torch_version, is_torch_version
+from ..utils.import_utils import (
+    _torch_version,
+    is_intel_extension_for_transformers_available,
+    is_torch_version,
+    requires_backends,
+)
 from .configuration import INCConfig
 from .utils import WEIGHTS_NAME
@@ -63,6 +68,11 @@
 """
 
 
+if is_intel_extension_for_transformers_available():
+    from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM as ITREX_WOQ_MODEL
+    from intel_extension_for_transformers.transformers.utils import WeightOnlyQuantConfig
+
+
 class INCModel(OptimizedModel):
     auto_model_class = AutoModel
     export_feature = "feature-extraction"
@@ -131,6 +141,25 @@ def _from_pretrained(
             model_save_dir = Path(model_cache_path).parent
         inc_config = None
         msg = None
+        try:
+            requires_backends(cls, ["intel_extension_for_transformers"])
+            quantization_config = WeightOnlyQuantConfig.from_pretrained(model_id)
+            if getattr(
+                quantization_config, "algorithm", None
+            ) is not None and quantization_config.algorithm.lower() in ["rtn", "gptq", "awq", "autoround"]:
+                return ITREX_WOQ_MODEL.from_pretrained(
+                    pretrained_model_name_or_path=model_id,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    force_download=force_download,
+                    cache_dir=cache_dir,
+                    local_files_only=local_files_only,
+                    subfolder=subfolder,
+                    trust_remote_code=trust_remote_code,
+                    **kwargs,
+                )
+        except EnvironmentError:
+            msg = "The model is not quantized with weight-only quantization."
         try:
             inc_config = INCConfig.from_pretrained(model_id)
             if not is_torch_version("==", inc_config.torch_version):
diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index 5eb4da8cef..6b5ea38a16 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -15,6 +15,7 @@
 import copy
 import inspect
 import logging
+import types
 from enum import Enum
 from itertools import chain
 from pathlib import Path
@@ -46,10 +47,15 @@
 from ..utils.constant import _TASK_ALIASES, MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
 from ..utils.import_utils import (
+    _intel_extension_for_transformers_version,
     _ipex_version,
     _neural_compressor_version,
+    _torch_version,
+    is_intel_extension_for_transformers_available,
+    is_intel_extension_for_transformers_version,
     is_ipex_version,
     is_neural_compressor_version,
+    is_torch_version,
 )
 from .configuration import INCConfig
 from .modeling_base import (  # noqa
@@ -65,6 +71,28 @@
 from .utils import INCDataLoader, _cfgs_to_fx_cfgs
 
 
+if is_intel_extension_for_transformers_available():
+    INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION = "1.3.2"
+    if is_intel_extension_for_transformers_version("!=", INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION):
+        raise ImportError(
+            f"Found an incompatible version of `intel-extension-for-transformers`. Found version {_intel_extension_for_transformers_version}, "
+            f"but only version {INTEL_EXTENSION_FOR_TRANSFORMERS_MINIMUM_VERSION} is supported."
+        )
+    TORCH_VERSION = "2.1.0"
+    if is_torch_version("!=", TORCH_VERSION):
+        raise ImportError(
+            f"Found an incompatible version of `torch`. Found version {_torch_version}, "
+            f"but only version {TORCH_VERSION} is supported."
+        )
+
+    from intel_extension_for_transformers.llm.quantization.utils import convert_to_quantized_model
+    from intel_extension_for_transformers.transformers.modeling.modeling_auto import save_low_bit
+    from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
+
+    Config = Union[PostTrainingQuantConfig, WeightOnlyQuantConfig]
+else:
+    Config = PostTrainingQuantConfig
+
 logger = logging.getLogger(__name__)
 
 NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0"
@@ -132,7 +160,7 @@ def from_pretrained(cls, model: PreTrainedModel, **kwargs):
 
     def quantize(
         self,
-        quantization_config: "PostTrainingQuantConfig",
+        quantization_config: Config,
         save_directory: Union[str, Path],
         calibration_dataset: Dataset = None,
         batch_size: int = 8,
@@ -146,7 +174,7 @@ def quantize(
         Quantize a model given the optimization specifications defined in `quantization_config`.
 
         Args:
-            quantization_config (`PostTrainingQuantConfig`):
+            quantization_config (`Union[PostTrainingQuantConfig, WeightOnlyQuantConfig]`):
                 The configuration containing the parameters related to quantization.
             save_directory (`Union[str, Path]`):
                 The directory where the quantized model should be saved.
@@ -165,8 +193,11 @@ def quantize(
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
         save_onnx_model = kwargs.pop("save_onnx_model", False)
+        device = kwargs.pop("device", "cpu")
+        use_cpu = device == torch.device("cpu") or device == "cpu"
+        use_xpu = device == torch.device("xpu") or device == "xpu"
 
-        if save_onnx_model and isinstance(self._original_model, ORTModel):
+        if save_onnx_model and (isinstance(self._original_model, ORTModel) or weight_only):
             save_onnx_model = False
             logger.warning("Model provided is an ONNX model, `save_onnx_model` is set to False")
 
@@ -174,24 +205,33 @@ def quantize(
         calibration_dataloader = None
         self._set_task()
 
-        if weight_only:
+        if weight_only or not isinstance(quantization_config, PostTrainingQuantConfig):
             # check neural-compressor version
             if is_neural_compressor_version("<", NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION):
                 raise ImportError(
                     f"Found an incompatible version of neural-compressor. Found version {_neural_compressor_version}, "
                     f"but only version {NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION} or higher supports weight-only quantization."
                 )
+            if not is_intel_extension_for_transformers_available():
+                raise ImportError(
+                    "Could not find the intel-extension-for-transformers package. "
+                    "Please install it with: pip install intel-extension-for-transformers and pip install peft."
+                )
 
-            # If op_type_dict of quantization_config is not defined, it will use default values for weight-only quantization:
-            # {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "RTN"}
-            if isinstance(quantization_config.op_type_dict, dict) and len(quantization_config.op_type_dict) > 0:
-                algo = []
-                for _, val in quantization_config.op_type_dict.items():
-                    algo += val.get("weight", {}).get("algorithm", ["RTN"])
+            if quantization_config is None:
+                quantization_config = WeightOnlyQuantConfig()
+                algo = "RTN"
+            elif isinstance(quantization_config, WeightOnlyQuantConfig):
+                algo = quantization_config.algorithm
             else:
-                algo = ["RTN"]
+                raise TypeError(
+                    f"For weight-only quantization, `quantization_config` should be an instance of `WeightOnlyQuantConfig`, but got: {type(quantization_config)} instead."
+                )
+
+            if algo not in ["RTN", "GPTQ"]:
+                raise ValueError(f"Weight-only quantization only supports the RTN and GPTQ algorithms for now, but got {algo}.")
 
-            if calibration_dataset is None and ("GPTQ" in algo or "AWQ" in algo):
+            if calibration_dataset is None and quantization_config.tokenizer is None and ("GPTQ" in algo):
                 raise ValueError(
                     "Weight-only quantization needs a calibration dataset for both GPTQ and AWQ methodologies."
                 )
@@ -206,6 +246,9 @@ def quantize(
                     data_collator=data_collator,
                     use_label=False if "GPTQ" in algo else True,
                 )
+                quantization_config.calib_dataloader = calibration_dataloader
+
+            save_onnx_model = False
 
         elif INCQuantizationMode(quantization_config.approach) == INCQuantizationMode.STATIC:
             # Since PyTorch fx trace does not really require an example_inputs, only need calibration_dataset or calibration_fn here.
@@ -238,7 +281,8 @@ def quantize(
                 save_onnx_model = False
 
         if (
-            quantization_config.backend == "ipex"
+            isinstance(quantization_config, PostTrainingQuantConfig)
+            and quantization_config.backend == "ipex"
             and is_ipex_version("<", IPEX_MINIMUM_VERSION)
             and "generation" in self.task
         ):
@@ -247,76 +291,97 @@ def quantize(
            raise ImportError(
                 f"Found an incompatible version of intel-extension-for-pytorch. Found version {_ipex_version}, "
                 f"but only version {IPEX_MINIMUM_VERSION} or higher is supported."
             )
 
-        if isinstance(self._original_model.config, PretrainedConfig):
-            self._original_model.config.backend = quantization_config.backend
-
-        if isinstance(self._original_model, ORTModel):
-            # TODO : enable seq2seq models
-            if isinstance(self._original_model, ORTModelForConditionalGeneration):
-                raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")
-
-            if isinstance(self._original_model, ORTModelForCausalLM):
-                model_or_path = self._original_model.onnx_paths
-                if len(model_or_path) > 1:
-                    raise RuntimeError(
-                        f"Too many ONNX model files were found in {self._original_model.onnx_paths}, only `use_cache=False` is supported"
-                    )
-                model_or_path = str(model_or_path[0])
-                default_name = ONNX_DECODER_NAME
-            else:
-                model_or_path = str(self._original_model.model_path)
+        if not isinstance(quantization_config, PostTrainingQuantConfig):
+            if use_cpu:
+                # will be removed after the intel-extension-for-transformers 1.3.3 release.
+                quantization_config.device = "cpu"
+                quantization_config.post_init()
+            elif use_xpu:
+                # will be removed after the intel-extension-for-transformers 1.3.3 release.
+                quantization_config.device = "xpu"
+                quantization_config.post_init_xpu()
+            self._quantized_model = convert_to_quantized_model(
+                self._original_model, quantization_config, device=quantization_config.device
+            )
+            # will be removed after the intel-extension-for-transformers 1.3.3 release.
+            if hasattr(quantization_config, "calib_dataloader"):
+                quantization_config.calib_dataloader = None
+            self._quantized_model.quantization_config = quantization_config
+            self._quantized_model.save_pretrained = types.MethodType(save_low_bit, self._quantized_model)
+            # Save the quantized model
+            self._quantized_model.save_pretrained(save_directory)
         else:
-            model_or_path = self._original_model
-
-        compressed_model = fit(
-            model_or_path,
-            conf=quantization_config,
-            calib_dataloader=calibration_dataloader,
-            eval_func=self.eval_fn,
-            calib_func=self.calibration_fn,
-        )
-
-        if not hasattr(compressed_model, "_model") or compressed_model._model is None:
-            raise RuntimeError(
-                "The maximum number of trials specified has been reached and no quantized model meeting the specified"
-                " accuracy tolerance has been found. Either the tolerance or the number of trials need to be increased."
+            if isinstance(self._original_model.config, PretrainedConfig):
+                self._original_model.config.backend = quantization_config.backend
+
+            if isinstance(self._original_model, ORTModel):
+                # TODO : enable seq2seq models
+                if isinstance(self._original_model, ORTModelForConditionalGeneration):
+                    raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")
+
+                if isinstance(self._original_model, ORTModelForCausalLM):
+                    model_or_path = self._original_model.onnx_paths
+                    if len(model_or_path) > 1:
+                        raise RuntimeError(
+                            f"Too many ONNX model files were found in {self._original_model.onnx_paths}, only `use_cache=False` is supported"
+                        )
+                    model_or_path = str(model_or_path[0])
+                    default_name = ONNX_DECODER_NAME
+                else:
+                    model_or_path = str(self._original_model.model_path)
+            else:
+                model_or_path = self._original_model
+
+            compressed_model = fit(
+                model_or_path,
+                conf=quantization_config,
+                calib_dataloader=calibration_dataloader,
+                eval_func=self.eval_fn,
+                calib_func=self.calibration_fn,
             )
 
-        if isinstance(self._original_model.config, PretrainedConfig):
-            # If backend is IPEX, then the quantized model is JIT model which will drop the config attribute,
-            # so need set config from original_model.
-            model_config = copy.deepcopy(self._original_model.config)
-            model_config.torch_dtype = "int8"
-            if isinstance(compressed_model, IPEXModel):
-                model_config.torchscript = True
-                model_config.backend = "ipex"
-            elif not isinstance(compressed_model, ONNXModel):
-                compressed_model._model.config = model_config
-            model_config.save_pretrained(save_directory)
-
-        self._quantized_model = compressed_model._model
-
-        if save_onnx_model:
-            model_type = self._original_model.config.model_type.replace("_", "-")
-            model_name = getattr(self._original_model, "name", None)
-            onnx_config_class = TasksManager.get_exporter_config_constructor(
-                exporter="onnx",
-                model=self._original_model,
-                task=self.task,
-                model_type=model_type,
-                model_name=model_name,
-            )
-            onnx_config = onnx_config_class(self._original_model.config)
-            compressed_model.eval()
-            output_onnx_path = save_directory.joinpath(ONNX_WEIGHTS_NAME)
-            # Export the compressed model to the ONNX format
-            self._onnx_export(compressed_model, onnx_config, output_onnx_path)
-
-        output_path = save_directory.joinpath(file_name or default_name)
-        # Save the quantized model
-        self._save_pretrained(compressed_model, output_path)
-        quantization_config = INCConfig(quantization=quantization_config, save_onnx_model=save_onnx_model)
-        quantization_config.save_pretrained(save_directory)
+            if not hasattr(compressed_model, "_model") or compressed_model._model is None:
+                raise RuntimeError(
+                    "The maximum number of trials specified has been reached and no quantized model meeting the specified"
+                    " accuracy tolerance has been found. Either the tolerance or the number of trials need to be increased."
+                )
+
+            if isinstance(self._original_model.config, PretrainedConfig):
+                # If the backend is IPEX, the quantized model is a JIT model, which drops the config attribute,
+                # so the config needs to be set from the original model.
+                model_config = copy.deepcopy(self._original_model.config)
+                model_config.torch_dtype = "int8"
+                if isinstance(compressed_model, IPEXModel):
+                    model_config.torchscript = True
+                    model_config.backend = "ipex"
+                elif not isinstance(compressed_model, ONNXModel):
+                    compressed_model._model.config = model_config
+                model_config.save_pretrained(save_directory)
+
+            self._quantized_model = compressed_model._model
+
+            if save_onnx_model:
+                model_type = self._original_model.config.model_type.replace("_", "-")
+                model_name = getattr(self._original_model, "name", None)
+                onnx_config_class = TasksManager.get_exporter_config_constructor(
+                    exporter="onnx",
+                    model=self._original_model,
+                    task=self.task,
+                    model_type=model_type,
+                    model_name=model_name,
+                )
+                onnx_config = onnx_config_class(self._original_model.config)
+                compressed_model.eval()
+                output_onnx_path = save_directory.joinpath(ONNX_WEIGHTS_NAME)
+                # Export the compressed model to the ONNX format
+                self._onnx_export(compressed_model, onnx_config, output_onnx_path)
+
+            output_path = save_directory.joinpath(file_name or default_name)
+            # Save the quantized model
+            self._save_pretrained(compressed_model, output_path)
+            quantization_config = INCConfig(quantization=quantization_config, save_onnx_model=save_onnx_model)
+            quantization_config.save_pretrained(save_directory)
+        return self._quantized_model
 
     @staticmethod
     def _save_pretrained(model: Union[PyTorchModel, IPEXModel], output_path: str):
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 1d5ce25086..08a9ec1f88 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -61,6 +61,16 @@
 _neural_compressor_available = False
 
 
+_intel_extension_for_transformers_available = importlib.util.find_spec("intel_extension_for_transformers") is not None
+_intel_extension_for_transformers_version = "N/A"
+if _intel_extension_for_transformers_available:
+    try:
+        _intel_extension_for_transformers_version = importlib_metadata.version("intel_extension_for_transformers")
+        logging.warn("`transformers` version >= 4.31 is required by intel-extension-for-transformers.")
+    except importlib_metadata.PackageNotFoundError:
+        _intel_extension_for_transformers_available = False
+
+
 _ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
 _ipex_version = "N/A"
 if _ipex_available:
@@ -174,6 +184,10 @@ def is_neural_compressor_available():
     return _neural_compressor_available
 
 
+def is_intel_extension_for_transformers_available():
+    return _intel_extension_for_transformers_available
+
+
 def is_ipex_available():
     return _ipex_available
 
@@ -253,6 +267,15 @@ def is_neural_compressor_version(operation: str, version: str):
     return compare_versions(parse(_neural_compressor_version), operation, version)
 
 
+def is_intel_extension_for_transformers_version(operation: str, version: str):
+    """
+    Compare the current intel_extension_for_transformers version to a given reference with an operation.
+    """
+    if not _intel_extension_for_transformers_available:
+        return False
+    return compare_versions(parse(_intel_extension_for_transformers_version), operation, version)
+
+
 def is_openvino_version(operation: str, version: str):
     """
     Compare the current OpenVINO version to a given reference with an operation.
@@ -326,6 +349,11 @@ def is_timm_version(operation: str, version: str):
 `pip install neural-compressor`. Please note that you may need to restart your runtime after installation.
""" +INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR = """ +{0} requires the intel-extension-for-transformers library but it was not found in your environment. You can install it with pip: +`pip install intel-extension-for-transformers` and `pip install peft`. Please note that you may need to restart your runtime after installation. +""" + DATASETS_IMPORT_ERROR = """ {0} requires the datasets library but it was not found in your environment. You can install it with pip: `pip install datasets`. Please note that you may need to restart your runtime after installation. @@ -343,6 +371,10 @@ def is_timm_version(operation: str, version: str): ("nncf", (is_nncf_available, NNCF_IMPORT_ERROR)), ("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)), ("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)), + ( + "intel_extension_for_transformers", + (is_intel_extension_for_transformers_available, INTEL_EXTENSION_FOR_TRANSFORMERS_IMPORT_ERROR), + ), ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), ] ) diff --git a/setup.py b/setup.py index dc5c417218..e80d0ea448 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,11 @@ QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] EXTRAS_REQUIRE = { - "neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"], + "neural-compressor": [ + "neural-compressor>=2.2.0", + "onnxruntime<1.15.0", + "accelerate", + ], "openvino": ["openvino>=2023.3", "nncf>=2.8.1"], "openvino-tokenizers": ["openvino-tokenizers[transformers]"], "nncf": ["nncf>=2.8.1"], diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index 61a54128b4..88e203a517 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -17,6 +17,7 @@ import os import tempfile +import copy import unittest import evaluate @@ -44,7 +45,7 @@ set_seed, ) from utils_tests import SEED, INCTestMixin, _generate_dataset -from optimum.intel.utils.import_utils import is_torch_version +from optimum.intel.utils.import_utils import is_torch_version, is_intel_extension_for_transformers_available from optimum.intel import ( @@ -63,6 +64,8 @@ from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification from optimum.pipelines import ORT_SUPPORTED_TASKS +if is_intel_extension_for_transformers_available(): + from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig os.environ["CUDA_VISIBLE_DEVICES"] = "" set_seed(SEED) @@ -84,6 +87,13 @@ class OptimizationTest(INCTestMixin): "hf-internal-testing/tiny-random-GPTNeoForCausalLM", ) + WEIGHT_ONLY_CONFIG = ( + (False, "RTN", "int4_clip"), + (False, "GPTQ", "int4_clip"), + (False, "RTN", "int8"), + (True, "", ""), + ) + @parameterized.expand(SUPPORTED_ARCHITECTURES_DYNAMIC) def test_dynamic_quantization(self, task, model_name, expected_quantized_matmuls): quantization_config = PostTrainingQuantConfig(approach="dynamic") @@ -198,86 +208,43 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_name, expec load_ipex_model=True, ) - def test_weight_only_quantization(self): + @parameterized.expand(WEIGHT_ONLY_CONFIG) + @unittest.skipIf( + not is_intel_extension_for_transformers_available(), reason="Intel-extension-for-transformers not available!" 
+    )
+    def test_weight_only_quantization(self, no_config, algo, weight_dtype):
         model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
-        op_type_dict = {
-            ".*": {
-                "weight": {
-                    "bits": 8,
-                    "group_size": -1,
-                    "scheme": "sym",
-                    "algorithm": "RTN",
-                },
-            },
-        }
-        quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict)
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        quantizer = INCQuantizer.from_pretrained(model, task="text-generation")
+        quantizer = INCQuantizer.from_pretrained(copy.deepcopy(model), task="text-generation")
         calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=2)
         with tempfile.TemporaryDirectory() as tmp_dir:
-            quantizer.quantize(
-                quantization_config=quantization_config,
-                calibration_dataset=calibration_dataset,
-                save_directory=tmp_dir,
-                weight_only=True,
-            )
-            q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
-            inp = torch.tensor([calibration_dataset[0]["input_ids"]])
-            out = model(inp)[0]
-            q_out = q_model(inp)[0]
-            self.assertTrue(torch.all(torch.isclose(out, q_out, atol=5e-1)))
-
-        op_type_dict = {
-            ".*": {
-                "weight": {
-                    "bits": 8,
-                    "group_size": -1,
-                    "scheme": "sym",
-                    "algorithm": "AWQ",
-                },
-            },
-        }
-        quantization_config = PostTrainingQuantConfig(approach="weight_only", op_type_dict=op_type_dict)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            quantizer.quantize(
-                quantization_config=quantization_config,
-                calibration_dataset=calibration_dataset,
-                save_directory=tmp_dir,
-                weight_only=True,
-            )
-            q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
-            inp = torch.tensor([calibration_dataset[0]["input_ids"]])
-            out = model(inp)[0]
-            q_out = q_model(inp)[0]
-            self.assertTrue(torch.all(torch.isclose(out, q_out, atol=6e-1)))
-
-        op_type_dict = {
-            ".*": {
-                "weight": {
-                    "bits": 8,
-                    "group_size": -1,
-                    "scheme": "sym",
-                    "algorithm": "GPTQ",
-                },
-            },
-        }
-        recipes = {"gptq_args": {"pad_max_length": len(calibration_dataset[0]["input_ids"])}}
-        quantization_config = PostTrainingQuantConfig(
-            approach="weight_only", op_type_dict=op_type_dict, recipes=recipes
-        )
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            quantizer.quantize(
-                quantization_config=quantization_config,
-                calibration_dataset=calibration_dataset,
-                save_directory=tmp_dir,
-                weight_only=True,
-            )
-            q_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
+            if not no_config:
+                if algo == "GPTQ":
+                    algorithm_args = {
+                        "percdamp": 0.01,
+                        "act_order": False,
+                        "scheme": "sym",
+                    }
+                quantization_config = WeightOnlyQuantConfig(
+                    algorithm=algo,
+                    algorithm_args=algorithm_args if algo == "GPTQ" else None,
+                    weight_dtype=weight_dtype,
+                )
+                q_model = quantizer.quantize(
+                    quantization_config=quantization_config,
+                    calibration_dataset=calibration_dataset if algo == "GPTQ" else None,
+                    save_directory=tmp_dir,
+                )
+            else:
+                q_model = quantizer.quantize(
+                    quantization_config=None,
+                    save_directory=tmp_dir,
+                    weight_only=True,  # RTN quantization with the NF4 weight dtype is used by default.
+                )
+            q_model = INCModelForCausalLM.from_pretrained(tmp_dir)
             inp = torch.tensor([calibration_dataset[0]["input_ids"]])
             out = model(inp)[0]
             q_out = q_model(inp)[0]
diff --git a/tests/openvino/test_modeling_basic.py b/tests/openvino/test_modeling_basic.py
index a443c5fea7..9423ce5683 100644
--- a/tests/openvino/test_modeling_basic.py
+++ b/tests/openvino/test_modeling_basic.py
@@ -7,6 +7,7 @@
 This test is meant to run quickly with tiny test models. More extensive tests are in test_modeling.py.
 """
+
 # ruff: noqa
 
 import gc
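
For reference, the sketch below shows how the weight-only quantization flow introduced by this diff is intended to be used end to end. It is a minimal illustration based on the test and example code above, assuming `intel-extension-for-transformers` 1.3.2, `peft` and `neural-compressor` >= 2.3 are installed; the model name is only an example.

```python
# Minimal usage sketch of the new weight-only quantization path (assumptions noted above).
from intel_extension_for_transformers.transformers.utils.config import WeightOnlyQuantConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.intel.neural_compressor import INCModelForCausalLM, INCQuantizer

model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"  # example model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# RTN weight-only quantization; GPTQ would additionally require a calibration dataset.
quantization_config = WeightOnlyQuantConfig(algorithm="RTN", weight_dtype="int4_clip")

quantizer = INCQuantizer.from_pretrained(model, task="text-generation")
quantizer.quantize(quantization_config=quantization_config, save_directory="quantized_model")

# The saved model is reloaded through INCModelForCausalLM, which dispatches to the
# intel-extension-for-transformers loader for weight-only quantized checkpoints.
loaded_model = INCModelForCausalLM.from_pretrained("quantized_model")
outputs = loaded_model.generate(**tokenizer("This is a sample input", return_tensors="pt"))
```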