Fix itrex WOQ model loading (#730)
* Fix loading ITREX model

* add test

* fix loading WOQ and quantization config

* add test

* add revision and subfolder parameters when loading inc config

* style

* update test model id
echarlaix authored May 28, 2024
1 parent 7b4e50f commit bfd0767
Showing 5 changed files with 143 additions and 39 deletions.
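
For context, here is a minimal usage sketch of the loading and saving flow this commit fixes (not part of the diff). It assumes optimum-intel with intel-extension-for-transformers (ITREX) installed and reuses the tiny checkpoint exercised in the new test added below; treat it as an illustration rather than the exact internal code path.

# Sketch only: load a weight-only quantized (WOQ) ITREX checkpoint, then save and reload it locally.
from optimum.intel import INCModelForCausalLM
from transformers import AutoTokenizer

model_id = "echarlaix/tiny-random-PhiForCausalLM"

# `revision` and `subfolder` are now forwarded when resolving the weights,
# the quantization config and the INC config.
model = INCModelForCausalLM.from_pretrained(model_id, revision="itrex", subfolder="itrex")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="itrex")

outputs = model(**tokenizer("This is a sample output", return_tensors="pt"))

# Saving a WOQ ITREX model writes the weights plus quantize_config.json,
# so it can be reloaded as a quantized model from the local directory.
model.save_pretrained("./woq_model")
reloaded = INCModelForCausalLM.from_pretrained("./woq_model")
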
115 changes: 81 additions & 34 deletions optimum/intel/neural_compressor/modeling_base.py
@@ -22,6 +22,7 @@
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import EntryNotFoundError
from neural_compressor.utils.pytorch import load
from transformers import (
AutoConfig,
@@ -40,14 +41,15 @@
)
from transformers.modeling_utils import no_init_weights
from transformers.models.auto.auto_factory import _get_model_class
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from transformers.utils.generic import ContextManagers

from optimum.intel.generation import BaseModelForCausalLM

from ...modeling_base import OptimizedModel
from ..utils.import_utils import _torch_version, is_itrex_available, is_torch_version
from .configuration import INCConfig
from .utils import WEIGHTS_NAME
from .utils import QUANTIZATION_CONFIG_NAME


logger = logging.getLogger(__name__)
@@ -119,33 +121,70 @@ def _from_pretrained(
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

model_name_or_path = kwargs.pop("model_name_or_path", None)
if model_name_or_path is not None:
logger.warning("`model_name_or_path` is deprecated please use `model_id`")
model_id = model_id or model_name_or_path

model_path = Path(model_id)

if model_path.is_dir():
model_cache_path = model_path / file_name
is_local = model_path.is_dir()
model_cache_path = None
inc_config = None
msg = None
if is_local:
if (model_path / subfolder / SAFE_WEIGHTS_NAME).is_file():
file_name = SAFE_WEIGHTS_NAME
elif not (model_path / subfolder / file_name).is_file():
raise EnvironmentError(
f"Error no file named {SAFE_WEIGHTS_NAME} or {file_name} found in directory {model_path / subfolder}"
)
model_cache_path = model_path / subfolder / file_name
else:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
subfolder=subfolder,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
# Try to download safetensors weights if they exist
try:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=SAFE_WEIGHTS_NAME,
subfolder=subfolder,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
pass

if model_cache_path is None:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
subfolder=subfolder,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)

model_save_dir = Path(model_cache_path).parent
inc_config = None
msg = None

if is_itrex_available():
try:
quantization_config = PretrainedConfig.from_pretrained(model_save_dir / "quantize_config.json")
quantization_config_path = None
if is_local:
quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME
else:
try:
quantization_config_path = hf_hub_download(
repo_id=model_id,
filename=QUANTIZATION_CONFIG_NAME,
subfolder=subfolder,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
pass

if quantization_config_path and Path(quantization_config_path).is_file():
quantization_config = PretrainedConfig.from_pretrained(quantization_config_path)
algorithm = getattr(quantization_config, "quant_method", None)
if algorithm in {"rtn", "gptq", "awq", "autoround"}:
from intel_extension_for_transformers.transformers.modeling.modeling_auto import (
@@ -154,7 +193,7 @@ def _from_pretrained(

_BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class

return _BaseQBitsAutoModelClass.from_pretrained(
model = _BaseQBitsAutoModelClass.from_pretrained(
pretrained_model_name_or_path=model_id,
token=token,
revision=revision,
@@ -163,12 +202,16 @@ def _from_pretrained(
local_files_only=local_files_only,
subfolder=subfolder,
trust_remote_code=trust_remote_code,
use_neural_speed=False,
**kwargs,
)
except EnvironmentError:
msg = "The model is not quantized with weight-only quantization."

return cls(
model, config=config, model_save_dir=model_save_dir, q_config=quantization_config, **kwargs
)

try:
inc_config = INCConfig.from_pretrained(model_id)
inc_config = INCConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision)
if not is_torch_version("==", inc_config.torch_version):
msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
logger.warning(f"{msg}")
@@ -209,15 +252,19 @@ def _from_pretrained(
)

def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)

if isinstance(self.model, torch.nn.Module):
state_dict = self.model.state_dict()
if self._q_config:
state_dict["best_configure"] = self._q_config
torch.save(state_dict, output_path)
# For ITREX model
if isinstance(self._q_config, PretrainedConfig):
self._q_config.to_json_file(os.path.join(save_directory, QUANTIZATION_CONFIG_NAME))
self.model.save_pretrained(save_directory)
# For INC model the state dictionary needs to be modified to include the quantization parameters
else:
state_dict = self.model.state_dict()
if isinstance(self._q_config, dict):
state_dict["best_configure"] = self._q_config
torch.save(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
else:
torch.jit.save(self.model, output_path)
torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME))

if self.inc_config:
self.inc_config.save_pretrained(save_directory)
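
To summarize the loading change above: for remote models the weights are now resolved by first trying to download the safetensors file and falling back to the requested weight file, while quantize_config.json is fetched separately to decide whether to dispatch to ITREX. Below is a condensed sketch of that resolution order, using a hypothetical helper name; it is a simplified restatement of the logic in `_from_pretrained`, not a separate API.

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers.utils import SAFE_WEIGHTS_NAME

def resolve_remote_weights(model_id, file_name, **hub_kwargs):
    # Prefer safetensors when the repository provides them ...
    try:
        return hf_hub_download(repo_id=model_id, filename=SAFE_WEIGHTS_NAME, **hub_kwargs)
    except EntryNotFoundError:
        # ... otherwise fall back to the explicitly requested weight file.
        return hf_hub_download(repo_id=model_id, filename=file_name, **hub_kwargs)
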
1 change: 1 addition & 0 deletions optimum/intel/neural_compressor/utils.py
@@ -28,6 +28,7 @@


CONFIG_NAME = "best_configure.yaml"
QUANTIZATION_CONFIG_NAME = "quantize_config.json"

NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0"
NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0"
63 changes: 60 additions & 3 deletions tests/neural_compressor/test_modeling.py
@@ -16,10 +16,12 @@
import os
import tempfile
import unittest
from pathlib import Path

import torch
from parameterized import parameterized
from transformers import AutoTokenizer, pipeline, set_seed
from transformers.utils import SAFE_WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.intel import ( # noqa
@@ -37,7 +39,8 @@
INCStableDiffusionPipeline,
INCTrainer,
)
from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, WEIGHTS_NAME
from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS, QUANTIZATION_CONFIG_NAME, WEIGHTS_NAME
from optimum.intel.utils.import_utils import is_itrex_available


os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -52,7 +55,7 @@


MODEL_NAMES_TO_TASK = (
("hf-internal-testing/tiny-random-gpt2", "text-generation"),
("hf-internal-testing/tiny-random-GPT2LMHeadModel", "text-generation"),
("hf-internal-testing/tiny-random-BertForMaskedLM", "fill-mask"),
("hf-internal-testing/tiny-random-DistilBertForSequenceClassification", "text-classification"),
("hf-internal-testing/tiny-random-DebertaV2Model", "feature-extraction"),
@@ -86,7 +89,7 @@ def test_compare_to_transformers(self, model_id, task):
outputs = inc_model(**model_inputs)
with tempfile.TemporaryDirectory() as tmpdirname:
inc_model.save_pretrained(tmpdirname)
loaded_model = model_class.from_pretrained(tmpdirname, file_name=WEIGHTS_NAME)
loaded_model = model_class.from_pretrained(tmpdirname)
outputs_loaded = loaded_model(**model_inputs)

if task == "feature-extraction":
@@ -143,3 +146,57 @@ def test_compare_with_and_without_past_key_values(self):
self.assertEqual(outputs_with_pkv.shape[1], self.GENERATION_LENGTH)
self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH)
self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv))

@unittest.skipIf(not is_itrex_available(), reason="ITREX not available")
def test_saving_loading_woq_itrex_model(self):
model_name = "echarlaix/tiny-random-PhiForCausalLM"
subfolder = "itrex"
model = INCModelForCausalLM.from_pretrained(model_name, revision="itrex", subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(model_name, revision="itrex")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokens = tokenizer("This is a sample output", return_tensors="pt")

with tempfile.TemporaryDirectory() as tmp_dir:
model_save_dir = Path(tmp_dir) / subfolder
model.save_pretrained(model_save_dir)
folder_contents = os.listdir(model_save_dir)
self.assertIn(SAFE_WEIGHTS_NAME, folder_contents)
self.assertIn(QUANTIZATION_CONFIG_NAME, folder_contents)
loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder)

with torch.no_grad():
outputs = model(**tokens)
loaded_outputs = loaded_model(**tokens)

self.assertTrue("logits" in loaded_outputs)
self.assertIsInstance(loaded_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in loaded_outputs)
self.assertIsInstance(loaded_outputs.past_key_values, tuple)
self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5))

def test_saving_loading_inc_model(self):
model_name = "echarlaix/tiny-random-PhiForCausalLM"
subfolder = "inc"
model = INCModelForCausalLM.from_pretrained(model_name, revision="inc", subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(model_name, revision="inc")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokens = tokenizer("This is a sample output", return_tensors="pt")

with tempfile.TemporaryDirectory() as tmp_dir:
model_save_dir = Path(tmp_dir) / subfolder
model.save_pretrained(model_save_dir)
folder_contents = os.listdir(model_save_dir)
self.assertIn(WEIGHTS_NAME, folder_contents)
self.assertIn("inc_config.json", folder_contents)
loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir, subfolder=subfolder)
self.assertIsInstance(loaded_model.inc_config, INCConfig)

with torch.no_grad():
outputs = model(**tokens)
loaded_outputs = loaded_model(**tokens)

self.assertTrue("logits" in loaded_outputs)
self.assertIsInstance(loaded_outputs.logits, torch.Tensor)
self.assertTrue("past_key_values" in loaded_outputs)
self.assertIsInstance(loaded_outputs.past_key_values, tuple)
self.assertTrue(torch.allclose(outputs.logits, loaded_outputs.logits, atol=1e-5))
1 change: 0 additions & 1 deletion tests/neural_compressor/test_optimization.py
@@ -47,7 +47,6 @@
from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset
from optimum.intel.utils.import_utils import is_torch_version, is_itrex_available


from optimum.intel import (
INCConfig,
INCModelForCausalLM,
2 changes: 1 addition & 1 deletion tests/neural_compressor/utils_tests.py
@@ -81,7 +81,7 @@
"electra": "hf-internal-testing/tiny-random-electra",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
