diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index c6d5e7bac0..bb3d2fe8c8 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -22,6 +22,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import EntryNotFoundError
 from neural_compressor.utils.pytorch import load
 from transformers import (
     AutoConfig,
@@ -40,6 +41,7 @@
 )
 from transformers.modeling_utils import no_init_weights
 from transformers.models.auto.auto_factory import _get_model_class
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from transformers.utils.generic import ContextManagers

 from optimum.intel.generation import BaseModelForCausalLM
@@ -47,7 +49,7 @@
 from ...modeling_base import OptimizedModel
 from ..utils.import_utils import _torch_version, is_itrex_available, is_torch_version
 from .configuration import INCConfig
-from .utils import WEIGHTS_NAME
+from .utils import QUANTIZATION_CONFIG_NAME


 logger = logging.getLogger(__name__)
@@ -119,33 +121,70 @@ def _from_pretrained(
                 raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
             token = use_auth_token

-        model_name_or_path = kwargs.pop("model_name_or_path", None)
-        if model_name_or_path is not None:
-            logger.warning("`model_name_or_path` is deprecated please use `model_id`")
-            model_id = model_id or model_name_or_path
-
         model_path = Path(model_id)
-
-        if model_path.is_dir():
-            model_cache_path = model_path / file_name
+        is_local = model_path.is_dir()
+        model_cache_path = None
+        inc_config = None
+        msg = None
+        if is_local:
+            if (model_path / subfolder / SAFE_WEIGHTS_NAME).is_file():
+                file_name = SAFE_WEIGHTS_NAME
+            elif not (model_path / subfolder / file_name).is_file():
+                raise EnvironmentError(
+                    f"Error no file named {SAFE_WEIGHTS_NAME} or {file_name} found in directory {model_path / subfolder}"
+                )
+            model_cache_path = model_path / subfolder / file_name
         else:
-            model_cache_path = hf_hub_download(
-                repo_id=model_id,
-                filename=file_name,
-                subfolder=subfolder,
-                token=token,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                local_files_only=local_files_only,
-            )
+            # Try download safetensors if exist
+            try:
+                model_cache_path = hf_hub_download(
+                    repo_id=model_id,
+                    filename=SAFE_WEIGHTS_NAME,
+                    subfolder=subfolder,
+                    token=token,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                )
+            except EntryNotFoundError:
+                pass
+
+            if model_cache_path is None:
+                model_cache_path = hf_hub_download(
+                    repo_id=model_id,
+                    filename=file_name,
+                    subfolder=subfolder,
+                    token=token,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                )

         model_save_dir = Path(model_cache_path).parent
-        inc_config = None
-        msg = None
+
         if is_itrex_available():
-            try:
-                quantization_config = PretrainedConfig.from_pretrained(model_save_dir / "quantize_config.json")
+            quantization_config_path = None
+            if is_local:
+                quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME
+            else:
+                try:
+                    quantization_config_path = hf_hub_download(
+                        repo_id=model_id,
+                        filename=QUANTIZATION_CONFIG_NAME,
+                        subfolder=subfolder,
+                        token=token,
+                        revision=revision,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        local_files_only=local_files_only,
+                    )
+                except EntryNotFoundError:
+                    pass
+
+            if quantization_config_path and Path(quantization_config_path).is_file():
+                quantization_config = PretrainedConfig.from_pretrained(quantization_config_path)
                 algorithm = getattr(quantization_config, "quant_method", None)
                 if algorithm in {"rtn", "gptq", "awq", "autoround"}:
                     from intel_extension_for_transformers.transformers.modeling.modeling_auto import (
@@ -154,7 +193,7 @@ def _from_pretrained(

                     _BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class

-                    return _BaseQBitsAutoModelClass.from_pretrained(
+                    model = _BaseQBitsAutoModelClass.from_pretrained(
                         pretrained_model_name_or_path=model_id,
                         token=token,
                         revision=revision,
@@ -163,12 +202,16 @@ def _from_pretrained(
                         local_files_only=local_files_only,
                         subfolder=subfolder,
                         trust_remote_code=trust_remote_code,
+                        use_neural_speed=False,
                         **kwargs,
                     )
-            except EnvironmentError:
-                msg = "The model is not quantized with weight-only quantization."
+
+                    return cls(
+                        model, config=config, model_save_dir=model_save_dir, q_config=quantization_config, **kwargs
+                    )
+
         try:
-            inc_config = INCConfig.from_pretrained(model_id)
+            inc_config = INCConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision)
             if not is_torch_version("==", inc_config.torch_version):
                 msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
                 logger.warning(f"{msg}")
@@ -209,15 +252,19 @@ def _from_pretrained(
         )

     def _save_pretrained(self, save_directory: Union[str, Path]):
-        output_path = os.path.join(save_directory, WEIGHTS_NAME)
-
         if isinstance(self.model, torch.nn.Module):
-            state_dict = self.model.state_dict()
-            if self._q_config:
-                state_dict["best_configure"] = self._q_config
-            torch.save(state_dict, output_path)
+            # For ITREX model
+            if isinstance(self._q_config, PretrainedConfig):
+                self._q_config.to_json_file(os.path.join(save_directory, QUANTIZATION_CONFIG_NAME))
+                self.model.save_pretrained(save_directory)
+            # For INC model the state dictionary needs to be modified to include the quantization parameters
+            else:
+                state_dict = self.model.state_dict()
+                if isinstance(self._q_config, dict):
+                    state_dict["best_configure"] = self._q_config
+                torch.save(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
         else:
-            torch.jit.save(self.model, output_path)
+            torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME))

         if self.inc_config:
             self.inc_config.save_pretrained(save_directory)
diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py
index 3173f5e1c4..84c1d6dc29 100644
--- a/optimum/intel/neural_compressor/utils.py
+++ b/optimum/intel/neural_compressor/utils.py
@@ -28,6 +28,7 @@


 CONFIG_NAME = "best_configure.yaml"
+QUANTIZATION_CONFIG_NAME = "quantize_config.json"

 NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0"
 NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0"
diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py
index e6ce4763f2..0c3e60969b 100644
--- a/tests/neural_compressor/test_modeling.py
+++ b/tests/neural_compressor/test_modeling.py
@@ -16,10 +16,12 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path

 import torch
 from parameterized import parameterized
 from transformers import AutoTokenizer, pipeline, set_seed
+from transformers.utils import SAFE_WEIGHTS_NAME

 from optimum.exporters import TasksManager
 from optimum.intel import (  # noqa
@@ -37,7 +39,8 @@
     INCStableDiffusionPipeline,
     INCTrainer,
 )
-from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME
+from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, QUANTIZATION_CONFIG_NAME, WEIGHTS_NAME
+from optimum.intel.utils.import_utils import is_itrex_available


 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -52,7 +55,7 @@


 MODEL_NAMES_TO_TASK = (
-    ("hf-internal-testing/tiny-random-gpt2", "text-generation"),
+    ("hf-internal-testing/tiny-random-GPT2LMHeadModel", "text-generation"),
     ("hf-internal-testing/tiny-random-BertForMaskedLM", "fill-mask"),
     ("hf-internal-testing/tiny-random-DistilBertForSequenceClassification", "text-classification"),
     ("hf-internal-testing/tiny-random-DebertaV2Model", "feature-extraction"),
@@ -86,7 +89,7 @@ def test_compare_to_transformers(self, model_id, task):
         outputs = inc_model(**model_inputs)
         with tempfile.TemporaryDirectory() as tmpdirname:
             inc_model.save_pretrained(tmpdirname)
-            loaded_model = model_class.from_pretrained(tmpdirname, file_name=WEIGHTS_NAME)
+            loaded_model = model_class.from_pretrained(tmpdirname)
             outputs_loaded = loaded_model(**model_inputs)

         if task == "feature-extraction":
@@ -143,3 +146,57 @@ def test_compare_with_and_without_past_key_values(self):
         self.assertEqual(outputs_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv))
+
+    @unittest.skipIf(not is_itrex_available(), reason="ITREX not available")
+    def test_saving_loading_woq_itrex_model(self):
+        model_name = "echarlaix/tiny-random-PhiForCausalLM"
+        subfolder = "itrex"
+        model = INCModelForCausalLM.from_pretrained(model_name, revision="itrex", subfolder=subfolder)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, revision="itrex")
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokens = tokenizer("This is a sample output", return_tensors="pt")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_save_dir = Path(tmp_dir) / subfolder
+            model.save_pretrained(model_save_dir)
+            folder_contents = os.listdir(model_save_dir)
+            self.assertIn(SAFE_WEIGHTS_NAME, folder_contents)
+            self.assertIn(QUANTIZATION_CONFIG_NAME, folder_contents)
+            loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder)
+
+        with torch.no_grad():
+            outputs = model(**tokens)
+            loaded_outputs = loaded_model(**tokens)
+
+        self.assertTrue("logits" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.logits, torch.Tensor)
+        self.assertTrue("past_key_values" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.past_key_values, tuple)
+        self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5))
+
+    def test_saving_loading_inc_model(self):
+        model_name = "echarlaix/tiny-random-PhiForCausalLM"
+        subfolder = "inc"
+        model = INCModelForCausalLM.from_pretrained(model_name, revision="inc", subfolder=subfolder)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, revision="inc")
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokens = tokenizer("This is a sample output", return_tensors="pt")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_save_dir = Path(tmp_dir) / subfolder
+            model.save_pretrained(model_save_dir)
+            folder_contents = os.listdir(model_save_dir)
+            self.assertIn(WEIGHTS_NAME, folder_contents)
+            self.assertIn("inc_config.json", folder_contents)
+            loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder)
+            self.assertIsInstance(loaded_model.inc_config, INCConfig)
+
+        with torch.no_grad():
+            outputs = model(**tokens)
+            loaded_outputs = loaded_model(**tokens)
+
+        self.assertTrue("logits" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.logits, torch.Tensor)
+        self.assertTrue("past_key_values" in loaded_outputs)
+        self.assertIsInstance(loaded_outputs.past_key_values, tuple)
+        self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5))
diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py
index da42586139..56f2a5bac3 100644
--- a/tests/neural_compressor/test_optimization.py
+++ b/tests/neural_compressor/test_optimization.py
@@ -47,7 +47,6 @@
 from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset

 from optimum.intel.utils.import_utils import is_torch_version, is_itrex_available
-
 from optimum.intel import (
     INCConfig,
     INCModelForCausalLM,
diff --git a/tests/neural_compressor/utils_tests.py b/tests/neural_compressor/utils_tests.py
index a6d09954f5..2106237589 100644
--- a/tests/neural_compressor/utils_tests.py
+++ b/tests/neural_compressor/utils_tests.py
@@ -81,7 +81,7 @@
     "electra": "hf-internal-testing/tiny-random-electra",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
-    "gpt2": "hf-internal-testing/tiny-random-gpt2",
+    "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
     "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
     "gptj": "hf-internal-testing/tiny-random-GPTJModel",