From 6bf5fbc4b4f46b10138744c266be966e55c7c19c Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 29 Jan 2024 11:04:29 +0100 Subject: [PATCH 01/14] Expose InferRequestWrapper class so it can be imported from elsewhere (#533) * Expose InferRequestWrapper class so it can be imported from elsewhere * Fix --- optimum/intel/openvino/quantization.py | 74 ++++++++++++++------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 9af0b9c9a6..cf816193c9 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -77,6 +77,44 @@ def batch_size(self): return batch_size +class InferRequestWrapper: + def __init__(self, request, data_cache=None): + self.request = request + if data_cache is None: + data_cache = [] + self.data_cache = data_cache + + def __call__(self, *args, **kwargs): + self.data_cache.append(*args) + return self.request(*args, **kwargs) + + def infer(self, inputs: Any = None, share_inputs: bool = False): + self.data_cache.append(inputs) + return self.request.infer(inputs, share_inputs) + + def start_async( + self, + inputs: Any = None, + userdata: Any = None, + share_inputs: bool = False, + *, + shared_memory: Any = None, + ): + self.data_cache.append(inputs) + self.request.infer(inputs, share_inputs, share_outputs=True) + + def wait(self): + pass + + def get_tensor(self, name: str): + return Tensor(self.request.results[name]) + + def __getattr__(self, attr): + if attr in self.__dict__: + return getattr(self, attr) + return getattr(self.request, attr) + + class OVQuantizer(OptimumQuantizer): """ Handle the NNCF quantization process. @@ -297,41 +335,7 @@ def _quantize_ovcausallm( subset_size = kwargs.get("subset_size", 300) data_cache = [] - class InferRequestWrapper: - def __init__(self, request): - self.request = request - - def __call__(self, *args, **kwargs): - data_cache.append(*args) - return self.request(*args, **kwargs) - - def infer(self, inputs: Any = None, share_inputs: bool = False): - data_cache.append(inputs) - return self.request.infer(inputs, share_inputs) - - def start_async( - self, - inputs: Any = None, - userdata: Any = None, - share_inputs: bool = False, - *, - shared_memory: Any = None, - ): - data_cache.append(inputs) - self.request.infer(inputs, share_inputs, share_outputs=True) - - def wait(self): - pass - - def get_tensor(self, name: str): - return Tensor(self.request.results[name]) - - def __getattr__(self, attr): - if attr in self.__dict__: - return getattr(self, attr) - return getattr(self.request, attr) - - self.model.request = InferRequestWrapper(self.model.request) + self.model.request = InferRequestWrapper(self.model.request, data_cache) for _, data in enumerate(calibration_dataloader): self.model.generate(**data, max_new_tokens=1) if len(data_cache) >= subset_size: From 6e79be1627133db56aa0b705088ed5055bae6928 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:55:16 +0100 Subject: [PATCH 02/14] Add IPEX models for audio and image classification tasks (#536) * add test * format * Add image classification task * Add test --- optimum/intel/__init__.py | 7 +- optimum/intel/generation/modeling.py | 4 +- optimum/intel/ipex/__init__.py | 3 + optimum/intel/ipex/modeling_base.py | 70 +++++++++-- optimum/intel/utils/dummy_ipex_objects.py | 33 +++++ tests/ipex/test_modeling.py | 142 ++++++++++++++++++++-- 6 files changed, 235 insertions(+), 24 deletions(-) 
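With the InferRequestWrapper class from patch 01/14 above now defined at module level, it can be imported and reused outside of OVQuantizer. A minimal, hypothetical sketch of collecting calibration inputs this way — the model ID and the calibration_samples iterable are illustrative placeholders, not part of the patch:

    from optimum.intel import OVModelForCausalLM
    from optimum.intel.openvino.quantization import InferRequestWrapper

    model = OVModelForCausalLM.from_pretrained("gpt2", export=True)
    model.compile()  # make sure an inference request exists before wrapping it

    data_cache = []
    model.request = InferRequestWrapper(model.request, data_cache)
    for sample in calibration_samples:  # hypothetical iterable of tokenized prompts
        model.generate(**sample, max_new_tokens=1)
    # data_cache now holds every input dict seen during inference and can be used
    # as a calibration dataset, mirroring the loop in _quantize_ovcausallm above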
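For the IPEX classes introduced in this patch (02/14), usage is expected to follow the pattern exercised by the new tests further down: export the model through from_pretrained(..., export=True) and hand it to a transformers pipeline. A short sketch, reusing the tiny test checkpoint and image URL that appear in the tests below as placeholders:

    from transformers import AutoFeatureExtractor, pipeline
    from optimum.intel import IPEXModelForImageClassification

    model_id = "hf-internal-testing/tiny-random-vit"
    model = IPEXModelForImageClassification.from_pretrained(model_id, export=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

    pipe = pipeline("image-classification", model=model, feature_extractor=feature_extractor)
    # returns a list of {"label": ..., "score": ...} dicts, as asserted in the new tests
    print(pipe("http://images.cocodataset.org/val2017/000000039769.jpg"))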
diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 674e622003..4039b66a1a 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -48,9 +48,11 @@ "IPEXModelForMaskedLM", "IPEXModelForTokenClassification", "IPEXModelForQuestionAnswering", + "IPEXModelForImageClassification", + "IPEXModelForAudioClassification", + "IPEXModel", ] - try: if not (is_openvino_available() and is_nncf_available()): raise OptionalDependencyNotAvailable() @@ -159,7 +161,10 @@ from .utils.dummy_ipex_objects import * else: from .ipex import ( + IPEXModel, + IPEXModelForAudioClassification, IPEXModelForCausalLM, + IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, IPEXModelForSequenceClassification, diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 810dc0666f..0abdafe666 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -66,13 +66,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False): model_inputs = prepare_jit_inputs(model, task, use_cache) - model.config.return_dict = False + model.config.return_dict = task not in {"text-generation", "audio-classification"} # check if the model_inputs is correct. model(**model_inputs) torch._C._jit_set_texpr_fuser_enabled(False) - if "past_key_values" in model_inputs.keys(): - model.config.return_dict = False if is_torch_version(">=", "2.1.0"): traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) else: diff --git a/optimum/intel/ipex/__init__.py b/optimum/intel/ipex/__init__.py index 90f183dcd7..a9ecc351b9 100644 --- a/optimum/intel/ipex/__init__.py +++ b/optimum/intel/ipex/__init__.py @@ -1,5 +1,8 @@ from optimum.intel.ipex.modeling_base import ( + IPEXModel, + IPEXModelForAudioClassification, IPEXModelForCausalLM, + IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, IPEXModelForSequenceClassification, diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index d7fb82ea98..a27f6ff8e3 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -25,7 +25,9 @@ from transformers import ( AutoConfig, AutoModel, + AutoModelForAudioClassification, AutoModelForCausalLM, + AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, @@ -68,6 +70,9 @@ def __init__( self.model.to(self._device) self.model_save_dir = model_save_dir + self.input_names = { + inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" + } # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 AutoConfig.register(self.base_model_prefix, AutoConfig) @@ -170,8 +175,22 @@ def _save_pretrained(self, save_directory: Union[str, Path]): output_path = os.path.join(save_directory, WEIGHTS_NAME) torch.jit.save(self.model, output_path) - def forward(self, *args, **kwargs): - outputs = self.model(*args, **kwargs) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor = None, + **kwargs, + ): + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + 
if "token_type_ids" in self.input_names: + inputs["token_type_ids"] = token_type_ids + + outputs = self.model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) def eval(self): @@ -196,14 +215,52 @@ class IPEXModelForSequenceClassification(IPEXModel): export_feature = "text-classification" +class IPEXModelForTokenClassification(IPEXModel): + auto_model_class = AutoModelForTokenClassification + export_feature = "token-classification" + + class IPEXModelForMaskedLM(IPEXModel): auto_model_class = AutoModelForMaskedLM export_feature = "fill-mask" -class IPEXModelForTokenClassification(IPEXModel): - auto_model_class = AutoModelForTokenClassification - export_feature = "token-classification" +class IPEXModelForImageClassification(IPEXModel): + auto_model_class = AutoModelForImageClassification + export_feature = "image-classification" + + def forward( + self, + pixel_values: torch.Tensor, + **kwargs, + ): + inputs = { + "pixel_values": pixel_values, + } + + outputs = self.model(**inputs) + return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) + + +class IPEXModelForAudioClassification(IPEXModel): + auto_model_class = AutoModelForAudioClassification + export_feature = "audio-classification" + + def forward( + self, + input_values: torch.Tensor, + attention_mask: torch.Tensor = None, + **kwargs, + ): + inputs = { + "input_values": input_values, + } + + if "attention_mask" in self.input_names: + inputs["attention_mask"] = attention_mask + + outputs = self.model(**inputs) + return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) class IPEXModelForQuestionAnswering(IPEXModel): @@ -233,9 +290,6 @@ def __init__( self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) self.model_dtype = kwargs.get("model_dtype", None) - self.input_names = { - inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" - } self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: diff --git a/optimum/intel/utils/dummy_ipex_objects.py b/optimum/intel/utils/dummy_ipex_objects.py index fd1a662b67..c451dd3956 100644 --- a/optimum/intel/utils/dummy_ipex_objects.py +++ b/optimum/intel/utils/dummy_ipex_objects.py @@ -22,6 +22,17 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["ipex"]) +class IPEXModel(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) + + class IPEXModelForSequenceClassification(metaclass=DummyObject): _backends = ["ipex"] @@ -75,3 +86,25 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["ipex"]) + + +class IPEXModelForImageClassification(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) + + +class IPEXModelForAudioClassification(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 22ef5ee348..03a2f5de5a 
100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -16,12 +16,15 @@ import time import unittest +import numpy as np +import requests import torch from parameterized import parameterized +from PIL import Image from transformers import ( + AutoFeatureExtractor, AutoModelForCausalLM, AutoModelForQuestionAnswering, - AutoModelForSequenceClassification, AutoTokenizer, PretrainedConfig, pipeline, @@ -29,13 +32,23 @@ ) from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS -from optimum.intel import IPEXModelForCausalLM, IPEXModelForQuestionAnswering, IPEXModelForSequenceClassification +from optimum.intel import ( + IPEXModel, + IPEXModelForAudioClassification, + IPEXModelForCausalLM, + IPEXModelForImageClassification, + IPEXModelForMaskedLM, + IPEXModelForQuestionAnswering, + IPEXModelForSequenceClassification, + IPEXModelForTokenClassification, +) SEED = 42 MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", + "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", "bert": "hf-internal-testing/tiny-random-bert", "bart": "hf-internal-testing/tiny-random-bart", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", @@ -43,6 +56,7 @@ "bloom": "hf-internal-testing/tiny-random-BloomModel", "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", + "convnext": "hf-internal-testing/tiny-random-convnext", "distilbert": "hf-internal-testing/tiny-random-distilbert", "electra": "hf-internal-testing/tiny-random-electra", "flaubert": "hf-internal-testing/tiny-random-flaubert", @@ -51,17 +65,25 @@ "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "levit": "hf-internal-testing/tiny-random-LevitModel", "llama": "fxmarty/tiny-llama-fast-tokenizer", "opt": "hf-internal-testing/tiny-random-OPTModel", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "mistral": "echarlaix/tiny-random-mistral", + "mobilenet_v1": "google/mobilenet_v1_0.75_192", + "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", + "mobilevit": "hf-internal-testing/tiny-random-mobilevit", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "stas/mt5-tiny-random", + "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", "squeezebert": "hf-internal-testing/tiny-random-squeezebert", "t5": "hf-internal-testing/tiny-random-t5", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "vit": "hf-internal-testing/tiny-random-vit", + "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "xlm": "hf-internal-testing/tiny-random-xlm", } @@ -75,11 +97,10 @@ def __exit__(self, type, value, traceback): self.elapsed = (time.perf_counter() - self.elapsed) * 1e3 -class IPEXModelForSequenceClassificationTest(unittest.TestCase): +class IPEXModelTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "albert", "bert", - "convbert", "distilbert", "electra", "flaubert", @@ -89,13 +110,15 @@ class IPEXModelForSequenceClassificationTest(unittest.TestCase): "xlm", ) + IPEX_MODEL_CLASS = IPEXModel + @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = 
IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = AutoModelForSequenceClassification.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a sample input" tokens = tokenizer(inputs, return_tensors="pt") @@ -103,20 +126,34 @@ def test_compare_to_transformers(self, model_arch): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) # Compare tensor outputs - self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + for output_name in {"logits", "last_hidden_state"}: + if output_name in transformers_outputs: + self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) tokenizer = AutoTokenizer.from_pretrained(model_id) - pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) + pipe = pipeline(self.IPEX_MODEL_CLASS.export_feature, model=model, tokenizer=tokenizer) text = "This restaurant is awesome" - outputs = pipe(text) + if self.IPEX_MODEL_CLASS.export_feature == "fill-mask": + text += tokenizer.mask_token + _ = pipe(text) self.assertEqual(pipe.device, model.device) - self.assertGreaterEqual(outputs[0]["score"], 0.0) - self.assertIsInstance(outputs[0]["label"], str) + + +class IPEXModelForSequenceClassificationTest(IPEXModelTest): + IPEX_MODEL_CLASS = IPEXModelForTokenClassification + + +class IPEXModelForTokenClassificationTest(IPEXModelTest): + IPEX_MODEL_CLASS = IPEXModelForSequenceClassification + + +class IPEXModelForMaskedLMTest(IPEXModelTest): + IPEX_MODEL_CLASS = IPEXModelForMaskedLM class IPEXModelForQuestionAnsweringTest(unittest.TestCase): @@ -245,3 +282,84 @@ def test_compare_with_and_without_past_key_values(self): f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", ) + + +class IPEXModelForAudioClassificationTest(unittest.TestCase): + IPEX_MODEL_CLASS = IPEXModelForAudioClassification + SUPPORTED_ARCHITECTURES = ( + "unispeech", + "wav2vec2", + ) + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + inputs = preprocessor(self._generate_random_audio_data(), return_tensors="pt") + with torch.no_grad(): + transformers_outputs = transformers_model(**inputs) + outputs = ipex_model(**inputs) + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, 
transformers_outputs.logits, atol=1e-3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor) + outputs = pipe([np.random.random(16000)]) + self.assertEqual(pipe.device, model.device) + self.assertTrue(all(item["score"] > 0.0 for item in outputs[0])) + + +class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "beit", + # "levit", + "mobilenet_v1", + "mobilenet_v2", + "mobilevit", + "resnet", + "vit", + ) + IPEX_MODEL_CLASS = IPEXModelForImageClassification + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = preprocessor(images=image, return_tensors="pt") + with torch.no_grad(): + transformers_outputs = transformers_model(**inputs) + outputs = ipex_model(**inputs) + self.assertIn("logits", outputs) + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor) + outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg") + self.assertEqual(pipe.device, model.device) + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertTrue(isinstance(outputs[0]["label"], str)) From 20df723cbf26fbb1cffcaeb3391d9b6bf74c3095 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Tue, 30 Jan 2024 14:43:20 +0400 Subject: [PATCH 03/14] =?UTF-8?q?relax=20requirements=20to=20have=20regist?= =?UTF-8?q?ered=20normalized=20config=20for=20usage=20con=E2=80=A6=20(#537?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * relax requirements to have registered normalized config for usage converted decoder models * add property for access to normalized config --- optimum/exporters/openvino/stateful.py | 5 +---- optimum/intel/openvino/modeling_decoder.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index e6ec1879a5..c90c2211ed 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -22,7 +22,6 @@ from openvino.runtime import opset13 from optimum.exporters import TasksManager from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version -from optimum.utils.normalized_config import NormalizedConfigManager def model_has_state(ov_model: ov.Model): @@ -217,9 +216,7 @@ def patch_stateful(config: PretrainedConfig, 
ov_model: ov.Model): batch_dim = 1 if config.model_type == "chatglm" else 0 fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim) - - normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1 + num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1 make_stateful( ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None ) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 1644022c29..64135266b3 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -27,7 +27,7 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import CausalLMOutputWithPast -from optimum.utils import NormalizedConfigManager +from optimum.utils.normalized_config import NormalizedConfigManager from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful from ...exporters.openvino.stateful import model_has_state @@ -132,7 +132,6 @@ def __init__( self.stateful = model_has_sinks self.main_input_name = "input_ids" self.num_pkv = 2 - self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) self.key_value_input_names = [key for key in self.input_names if "key_values" in key] self.key_value_output_names = [key for key in self.output_names if "present" in key] self._original_model = self.model.clone() # keep original model for serialization @@ -321,6 +320,13 @@ def reshape(self, batch_size: int, sequence_length: int): logger.warning("Static shapes are not supported for causal language model.") return self + @property + def normalized_config(self): + logger.warning( + "access to normalized_config attribute is deprecated and will be removed in future versions, please use config" + ) + return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config) + def compile(self): if self.request is None: super().compile() @@ -364,7 +370,7 @@ def forward( batch_size = input_ids.shape[0] if self.config.model_type == "bloom": - batch_size *= self.normalized_config.num_attention_heads + batch_size *= self.config.num_attention_heads inputs = {} past_len = 0 @@ -592,8 +598,8 @@ def _reorder_cache( if self.stateful: beam_idx = np.array(beam_idx) batch_size = beam_idx.shape[0] - indices = np.array(range(batch_size * self.normalized_config.num_attention_heads)) - indices = indices.reshape([batch_size, self.normalized_config.num_attention_heads]) + indices = np.array(range(batch_size * self.config.num_attention_heads)) + indices = indices.reshape([batch_size, self.config.num_attention_heads]) self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() return past_key_values else: From 1b5c3cb74421e199438d127ed2935db25c12787a Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:10:39 +0100 Subject: [PATCH 04/14] IPEX decoder model fix (#539) --- optimum/intel/ipex/inference.py | 11 ++--------- optimum/intel/ipex/modeling_base.py | 2 +- tests/ipex/test_inference.py | 4 +++- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py index d246e4a1cf..25145a6997 100644 --- 
a/optimum/intel/ipex/inference.py +++ b/optimum/intel/ipex/inference.py @@ -36,20 +36,13 @@ IPEXOPTForCausalLM, IPEXGPTBigCodeForCausalLM, IPEXModelForQuestionAnswering, + _MODEL_TYPE_TO_AUTOMODELS, ) from .utils import _HEAD_TO_AUTOMODELS -_MODEL_TYPE_TO_AUTOMODELS = { - "bloom": IPEXBloomForCausalLM, - "mpt": IPEXMPTForCausalLM, - "opt": IPEXOPTForCausalLM, - "big_code": IPEXGPTBigCodeForCausalLM, -} - - logger = logging.getLogger(__name__) IPEX_NOT_AVAILABLE_ERROR_MSG = ( @@ -149,7 +142,7 @@ def __enter__(self): model_type = getattr(self._original.config, "model_type", "").replace("_", "-") if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys(): - auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[task] + auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type] else: auto_model_class = eval(_HEAD_TO_AUTOMODELS[task]) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a27f6ff8e3..a522fef265 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -600,5 +600,5 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "bloom": IPEXBloomForCausalLM, "mpt": IPEXMPTForCausalLM, "opt": IPEXOPTForCausalLM, - "big-code": IPEXGPTBigCodeForCausalLM, + "gpt-bigcode": IPEXGPTBigCodeForCausalLM, } diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py index c95467db64..7d9f862ef1 100644 --- a/tests/ipex/test_inference.py +++ b/tests/ipex/test_inference.py @@ -42,6 +42,8 @@ "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "llama": "fxmarty/tiny-llama-fast-tokenizer", + "opt": "hf-internal-testing/tiny-random-OPTModel", + "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", } _CLASSIFICATION_TASK_TO_AUTOMODELS = { @@ -57,7 +59,7 @@ class IPEXIntegrationTest(unittest.TestCase): "roberta", ) - TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode", "llama") + TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("gptj", "gpt2", "gpt_neo", "gpt_bigcode", "llama", "opt", "mpt") QA_SUPPORTED_ARCHITECTURES = ( "bert", From 3b627f4252cc48473113e7571a50476a825e0379 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:57:34 +0100 Subject: [PATCH 05/14] Enable loading of torchscript model with INC and add warning (#540) --- .../intel/neural_compressor/modeling_base.py | 53 +++++++++++++++---- tests/neural_compressor/test_modeling.py | 21 ++++++++ 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py index b74c08a573..72646a9f94 100644 --- a/optimum/intel/neural_compressor/modeling_base.py +++ b/optimum/intel/neural_compressor/modeling_base.py @@ -40,6 +40,8 @@ from transformers.models.auto.auto_factory import _get_model_class from transformers.utils.generic import ContextManagers +from optimum.intel.generation import BaseModelForCausalLM + from ...modeling_base import OptimizedModel from ..utils.import_utils import _torch_version, is_torch_version from .configuration import INCConfig @@ -83,11 +85,6 @@ def __init__( "cuda:0" if torch.cuda.is_available() else "cpu" ) - if getattr(self.config, "backend", None) == "ipex": - raise NotImplementedError( - "`INCModel` does not supported the loading of model resulting from IPEX, please use `IPEXModel` to load your model instead instead" - ) - # Registers 
the INCModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 AutoConfig.register(self.base_model_prefix, AutoConfig) @@ -143,11 +140,19 @@ def _from_pretrained( f"Please check if torch quantization the model was obtained with is compatible with {_torch_version}." ) + if getattr(config, "backend", None) == "ipex" or getattr(config, "torchscript", False): + logger.warning( + f"Using `{cls.__name__}` to load a TorchScript model will be deprecated in v1.15.0, to load your model please use `{cls.__name__.replace('INC', 'IPEX')}` instead." + ) + model = torch.jit.load(model_cache_path) + model = torch.jit.freeze(model.eval()) + return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs) + model_class = _get_model_class(config, cls.auto_model_class._model_mapping) # Load the state dictionary of the model to verify whether the model to get the quantization config state_dict = torch.load(model_cache_path, map_location="cpu") - q_config = state_dict.get("best_configure", None) + q_config = state_dict.get("best_configure", None) if q_config is None: model = model_class.from_pretrained(model_save_dir) else: @@ -169,10 +174,13 @@ def _from_pretrained( def _save_pretrained(self, save_directory: Union[str, Path]): output_path = os.path.join(save_directory, WEIGHTS_NAME) - state_dict = self.model.state_dict() - if self._q_config: - state_dict["best_configure"] = self._q_config - torch.save(state_dict, output_path) + if isinstance(self.model, torch.nn.Module): + state_dict = self.model.state_dict() + if self._q_config: + state_dict["best_configure"] = self._q_config + torch.save(state_dict, output_path) + else: + torch.jit.save(self.model, output_path) if self.inc_config: self.inc_config.save_pretrained(save_directory) @@ -244,6 +252,29 @@ class INCModelForXLNetLM(INCModel): export_feature = "fill-mask" -class INCModelForCausalLM(INCModel): +class INCModelForCausalLM(INCModel, BaseModelForCausalLM): auto_model_class = AutoModelForCausalLM export_feature = "text-generation" + forward = BaseModelForCausalLM.forward + generate = BaseModelForCausalLM.generate + can_generate = BaseModelForCausalLM.can_generate + + def __init__( + self, + model, + config: PretrainedConfig = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + q_config: Dict = None, + inc_config: Dict = None, + use_cache: bool = True, + **kwargs, + ): + super(INCModelForCausalLM, self).__init__( + model=model, + config=config, + model_save_dir=model_save_dir, + q_config=q_config, + inc_config=inc_config, + use_cache=use_cache, + **kwargs, + ) diff --git a/tests/neural_compressor/test_modeling.py b/tests/neural_compressor/test_modeling.py index e0a41e76af..e6ce4763f2 100644 --- a/tests/neural_compressor/test_modeling.py +++ b/tests/neural_compressor/test_modeling.py @@ -122,3 +122,24 @@ def test_pipeline(self, model_id, task): inputs *= 2 pipe(*inputs) + + def test_compare_with_and_without_past_key_values(self): + model_id = "echarlaix/tiny-random-gpt2-torchscript" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + + model_with_pkv = INCModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv") + + outputs_with_pkv = model_with_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, 
max_length=self.GENERATION_LENGTH, num_beams=1 + ) + model_without_pkv = INCModelForCausalLM.from_pretrained( + model_id, use_cache=False, subfolder="model_without_pkv" + ) + + outputs_without_pkv = model_without_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 + ) + self.assertEqual(outputs_with_pkv.shape[1], self.GENERATION_LENGTH) + self.assertEqual(outputs_without_pkv.shape[1], self.GENERATION_LENGTH) + self.assertTrue(torch.equal(outputs_with_pkv, outputs_without_pkv)) From a25142276e828baa1abc2cdde23bb92e57de61cb Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:54:52 +0100 Subject: [PATCH 06/14] Fix torch version for ipex tests (#545) * fix torch version for ipex tests * disbale tests for incompatible torch version with ipex * fix --- .github/workflows/test_ipex.yml | 1 + tests/neural_compressor/test_optimization.py | 4 ++++ tests/neural_compressor/utils_tests.py | 17 +++++++++++------ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml index 82c9e8c7f7..6b683d720b 100644 --- a/.github/workflows/test_ipex.yml +++ b/.github/workflows/test_ipex.yml @@ -30,6 +30,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cpu pip install .[ipex,tests] - name: Test with Pytest run: | diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index 3a7717d17a..61a54128b4 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -18,6 +18,7 @@ import os import tempfile +import unittest import evaluate import numpy as np import torch @@ -43,6 +44,8 @@ set_seed, ) from utils_tests import SEED, INCTestMixin, _generate_dataset +from optimum.intel.utils.import_utils import is_torch_version + from optimum.intel import ( INCConfig, @@ -163,6 +166,7 @@ def test_static_quantization(self, task, model_name, expected_quantized_matmuls) ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + @unittest.skipIf(is_torch_version(">=", "2.2.0"), "compatibility issue with torch 2.2.0 and IPEX latest version") def test_ipex_static_quantization_with_smoothquant(self, task, model_name, expected_quantized_matmuls): recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}} num_samples = 10 diff --git a/tests/neural_compressor/utils_tests.py b/tests/neural_compressor/utils_tests.py index a429ce6dd1..214aa73be5 100644 --- a/tests/neural_compressor/utils_tests.py +++ b/tests/neural_compressor/utils_tests.py @@ -40,18 +40,23 @@ INCStableDiffusionPipeline, ) -from optimum.intel.ipex import ( - IPEXModelForCausalLM, - IPEXModelForSequenceClassification, - IPEXModelForMaskedLM, - IPEXModelForTokenClassification, -) + +from optimum.intel.utils.import_utils import is_torch_version from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS from optimum.intel.utils.constant import ONNX_WEIGHTS_NAME from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification from optimum.pipelines import ORT_SUPPORTED_TASKS +if is_torch_version("<", "2.2.0"): + from optimum.intel.ipex import ( + IPEXModelForCausalLM, + IPEXModelForSequenceClassification, + IPEXModelForMaskedLM, + IPEXModelForTokenClassification, + ) + + SEED = 1009 
_TASK_TO_DATASET = { "text-classification": ("glue", "sst2", "sentence"), From 398450d8db6bfe495004469d6cac16ea4771d269 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Wed, 31 Jan 2024 13:23:54 +0200 Subject: [PATCH 07/14] Refactor IPEX CausalLM for better model architecture scale (#544) * Refactor IPEX CausalLM for better model arch scale * Fix style --- optimum/intel/ipex/inference.py | 13 +- optimum/intel/ipex/modeling_base.py | 244 ++-------------------------- 2 files changed, 16 insertions(+), 241 deletions(-) diff --git a/optimum/intel/ipex/inference.py b/optimum/intel/ipex/inference.py index 25145a6997..ccf2da9d80 100644 --- a/optimum/intel/ipex/inference.py +++ b/optimum/intel/ipex/inference.py @@ -31,12 +31,7 @@ IPEXModelForMaskedLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, - IPEXBloomForCausalLM, - IPEXMPTForCausalLM, - IPEXOPTForCausalLM, - IPEXGPTBigCodeForCausalLM, IPEXModelForQuestionAnswering, - _MODEL_TYPE_TO_AUTOMODELS, ) @@ -139,13 +134,7 @@ def __enter__(self): ) if task in _HEAD_TO_AUTOMODELS: model = jit_trace(model, task, use_cache) - model_type = getattr(self._original.config, "model_type", "").replace("_", "-") - - if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys(): - auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type] - else: - auto_model_class = eval(_HEAD_TO_AUTOMODELS[task]) - + auto_model_class = eval(_HEAD_TO_AUTOMODELS[task]) model = auto_model_class(model, self._original.config, use_cache=use_cache) # Enable automatic mixed precision (AMP) if we are going to target `bfloat16` diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a522fef265..b79f720348 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -36,7 +36,9 @@ GenerationMixin, PretrainedConfig, ) +from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput +from transformers.models.auto.auto_factory import _get_model_class as get_model_class from transformers.utils import WEIGHTS_NAME from optimum.exporters import TasksManager @@ -164,12 +166,8 @@ def _from_pretrained( model = torch.jit.load(model_cache_path) torch.jit.freeze(model.eval()) - model_type = config.model_type.replace("_", "-") - init_cls = cls - if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS: - init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type] - return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs) + return cls(model, config=config, model_save_dir=model_save_dir, **kwargs) def _save_pretrained(self, save_directory: Union[str, Path]): output_path = os.path.join(save_directory, WEIGHTS_NAME) @@ -302,6 +300,18 @@ def __init__( config.is_decoder = True config.is_encoder_decoder = False self.generation_config = GenerationConfig.from_model_config(config) + try: + self.model_cls = get_class_from_dynamic_module( + self.config.auto_map["AutoModelForCausalLM"], model_save_dir + ) + except AttributeError: + self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) + self._reorder_cache = self.model_cls._reorder_cache.__get__(self) + self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) + if hasattr(self.model_cls, "_convert_to_standard_cache"): + self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache + if hasattr(self.model_cls, "_convert_to_bloom_cache"): + 
self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") @@ -378,227 +388,3 @@ def forward( past_key_values = outputs["past_key_values"] if self.use_cache else None return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) - - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - input_ids = input_ids[:, remove_prefix_length:] - - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - attention_mask = kwargs.get("attention_mask", None) - use_cache = kwargs.get("use_cache", None) - position_ids = kwargs.get("position_ids", None) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "position_ids": position_ids, - "attention_mask": attention_mask, - } - - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - - -class IPEXGPTBigCodeForCausalLM(IPEXModelForCausalLM): - # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - # Omit tokens covered by past_key_values - if past_key_values: - if self.config.multi_query: - past_length = past_key_values[0].shape[1] - else: - past_length = past_key_values[0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - else: - position_ids = None - - model_inputs = {"input_ids": input_ids} - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - } - ) - return model_inputs - - -class IPEXBloomForCausalLM(IPEXModelForCausalLM): - # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - # Some generation 
methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - input_ids = input_ids[:, remove_prefix_length:] - - attention_mask = kwargs.get("attention_mask", None) - use_cache = kwargs.get("use_cache", None) - - # only last token for input_ids if past is not None - if past_key_values: - # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "position_ids": None, - "attention_mask": attention_mask, - } - - # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache - @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: - standardized_past = IPEXModelForCausalLM._convert_to_standard_cache(past, batch_size=len(beam_idx)) - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in standardized_past - ) - return IPEXModelForCausalLM._convert_to_bloom_cache(reordered_past) - - @staticmethod - def _convert_to_standard_cache( - past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int - ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]: - """ - Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, - num_heads, ...])) - """ - batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape - num_heads = batch_size_times_num_heads // batch_size - # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] - # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size, num_heads, head_dim, seq_length), - layer_past[1].view(batch_size, num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - @staticmethod - def _convert_to_bloom_cache( - past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]] - ) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]: - """ - Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) - """ - batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape - batch_size_times_num_heads = batch_size * num_heads - # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] - # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), - layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - -class IPEXOPTForCausalLM(IPEXModelForCausalLM): - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - input_ids = input_ids[:, remove_prefix_length:] - - attention_mask = kwargs.get("attention_mask", None) - use_cache = kwargs.get("use_cache", None) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "position_ids": None, - "attention_mask": attention_mask, - } - - -class IPEXMPTForCausalLM(IPEXModelForCausalLM): - # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - input_ids = input_ids[:, remove_prefix_length:] - - attention_mask = kwargs.get("attention_mask", None) - use_cache = kwargs.get("use_cache", None) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "position_ids": None, - "attention_mask": attention_mask, - } - - -_MODEL_TYPE_TO_AUTOMODELS = { - "bloom": IPEXBloomForCausalLM, - "mpt": IPEXMPTForCausalLM, - "opt": IPEXOPTForCausalLM, - "gpt-bigcode": IPEXGPTBigCodeForCausalLM, -} From 8ee487dc2ade5bd0023d1bbe0a0103d6af8821e0 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Wed, 31 Jan 2024 13:42:14 +0200 Subject: [PATCH 08/14] Automatic `torch.autocast` for IPEXModel (#542) * Handle autocast in IPEXModel.forward * Handle missing torch_dtype in config --- optimum/intel/ipex/modeling_base.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index b79f720348..901e90a421 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -69,6 +69,7 @@ def __init__( OptimizedModel.__init__(self, model=model, config=config) # To do: add XPU support self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model.to(self._device) self.model_save_dir = model_save_dir @@ -188,7 +189,7 @@ def forward( if "token_type_ids" in 
self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) def eval(self): @@ -199,6 +200,10 @@ def eval(self): def device(self) -> torch.device: return self._device + @property + def dtype(self) -> torch.dtype: + return self._dtype + def to(self, device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -207,6 +212,14 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) + def _call_model(self, *args, **kwargs): + try: + with torch.autocast(self.device.type, self.dtype): + out = self.model(*args, **kwargs) + except RuntimeError: + out = self.model(*args, **kwargs) + return out + class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -236,7 +249,7 @@ def forward( "pixel_values": pixel_values, } - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -257,7 +270,7 @@ def forward( if "attention_mask" in self.input_names: inputs["attention_mask"] = attention_mask - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -266,7 +279,7 @@ class IPEXModelForQuestionAnswering(IPEXModel): export_feature = "question-answering" def forward(self, *args, **kwargs): - outputs = self.model(*args, **kwargs) + outputs = self._call_model(*args, **kwargs) start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] return ModelOutput(start_logits=start_logits, end_logits=end_logits) @@ -287,7 +300,7 @@ def __init__( super().__init__(model, config, model_save_dir=model_save_dir) self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - self.model_dtype = kwargs.get("model_dtype", None) + self.model_dtype = kwargs.get("model_dtype", self.dtype) self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: @@ -377,7 +390,7 @@ def forward( inputs["past_key_values"] = past_key_values # 2. Model forward - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) # 3. 
Process model outputs if isinstance(outputs, (list, tuple)): From 788e4583ae2455a8af05bfa7682dbce163511e0e Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Wed, 31 Jan 2024 18:38:57 +0200 Subject: [PATCH 09/14] Add an initial warmup step to `IPEXModel`s (#543) * Handle autocast in IPEXModel.forward * Handle missing torch_dtype in config * Warmup IPEX models at init * Minor fix * Fix _init_warmup use_cache condition * Fix output handling in IPEX question answering --- optimum/intel/ipex/modeling_base.py | 37 +++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 901e90a421..2f7267c984 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -15,6 +15,7 @@ import logging import os +from functools import wraps from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union @@ -45,7 +46,7 @@ from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager -from ..generation.modeling import jit_trace +from ..generation.modeling import jit_trace, prepare_jit_inputs from ..utils.import_utils import is_torch_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask @@ -64,6 +65,7 @@ def __init__( model, config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + warmup: bool = True, **kwargs, ): OptimizedModel.__init__(self, model=model, config=config) @@ -81,6 +83,8 @@ def __init__( AutoConfig.register(self.base_model_prefix, AutoConfig) if hasattr(self.auto_model_class, "register"): self.auto_model_class.register(AutoConfig, self.__class__) + if warmup: + self._init_warmup() @classmethod def _from_transformers( @@ -220,6 +224,14 @@ def _call_model(self, *args, **kwargs): out = self.model(*args, **kwargs) return out + def _init_warmup(self): + # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and + # the results of the compute are unpredictable + use_cache = "past_key_values" in self.input_names + dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) + for _ in range(2): + self(**dummy_inputs) + class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -278,8 +290,21 @@ class IPEXModelForQuestionAnswering(IPEXModel): auto_model_class = AutoModelForQuestionAnswering export_feature = "question-answering" - def forward(self, *args, **kwargs): - outputs = self._call_model(*args, **kwargs) + def forward(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor = None, + **kwargs, + ): + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + if "token_type_ids" in self.input_names: + inputs["token_type_ids"] = token_type_ids + + outputs = self._call_model(**inputs) start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] return ModelOutput(start_logits=start_logits, end_logits=end_logits) @@ -295,9 +320,11 @@ def __init__( config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, + warmup: bool = True, **kwargs, ): - super().__init__(model, config, model_save_dir=model_save_dir) + # Perform the initial warmup at the end of __init__ + super().__init__(model, 
config, model_save_dir=model_save_dir, warmup=False) self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) self.model_dtype = kwargs.get("model_dtype", self.dtype) @@ -325,6 +352,8 @@ def __init__( self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache if hasattr(self.model_cls, "_convert_to_bloom_cache"): self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache + if warmup: + self._init_warmup() def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") From 0ca944716f0819e27a9cb71d830846970ab21745 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Wed, 31 Jan 2024 17:44:12 +0100 Subject: [PATCH 10/14] Fix format (#546) --- optimum/intel/ipex/modeling_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 2f7267c984..67810ae067 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -15,7 +15,6 @@ import logging import os -from functools import wraps from pathlib import Path from tempfile import TemporaryDirectory from typing import Optional, Tuple, Union @@ -290,7 +289,8 @@ class IPEXModelForQuestionAnswering(IPEXModel): auto_model_class = AutoModelForQuestionAnswering export_feature = "question-answering" - def forward(self, + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None, From 552de65a9c5f7fa1a2f0ce6859ebdeedaeaabe53 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 31 Jan 2024 18:42:12 +0100 Subject: [PATCH 11/14] Dev version --- optimum/intel/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/version.py b/optimum/intel/version.py index a00a083418..5945cc1e8c 100644 --- a/optimum/intel/version.py +++ b/optimum/intel/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "1.14.0.dev0" +__version__ = "1.15.0.dev0" From 7ea3656b84a325d7fba28e786592697905d35d8e Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 1 Feb 2024 17:29:33 +0100 Subject: [PATCH 12/14] Fix OV pre-commit test --- tests/openvino/test_training.py | 44 ++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index d932b7ff63..c306d2d530 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -276,7 +276,7 @@ def tearDown(self): shutil.rmtree(self.output_dir) -CUSTOMIZED_QUANTIZATION_CONFIG = { +TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG = { "algorithm": "quantization", "overflow_fix": "disable", "initializer": { @@ -288,7 +288,11 @@ def tearDown(self): "batchnorm_adaptation": {"num_bn_adaptation_samples": 4}, }, "scope_overrides": {"activations": {"{re}.*matmul_0": {"mode": "asymmetric"}}}, - "ignored_scopes": [], + "ignored_scopes": [ + "BertForSequenceClassification/BertModel[bert]/__rsub___0", + "BertForSequenceClassification/BertModel[bert]/__mul___0", + "{re}BertLayer\\[[0-9]+\\]/BertAttention\\[attention\\]/BertSelfAttention\\[self\\]/__add___0", + ], } STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT = { @@ -335,16 +339,16 @@ def tearDown(self): ), "customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=69, + nncf_compression_config=TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + expected_fake_quantize=64, expected_int8=35, compression_metrics=["compression_loss"], ), "distillation,customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=69, + nncf_compression_config=TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + expected_fake_quantize=64, expected_int8=35, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), @@ -371,8 +375,11 @@ def tearDown(self): ), "customized_quantization,structured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=69, + nncf_compression_config=[ + TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, + ], + expected_fake_quantize=64, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -389,8 +396,11 @@ def tearDown(self): "distillation,customized_quantization,structured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=69, + nncf_compression_config=[ + TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, + ], + expected_fake_quantize=64, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], @@ -418,8 +428,11 @@ def tearDown(self): ), "customized_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", - 
nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=69, + nncf_compression_config=[ + TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, + ], + expected_fake_quantize=64, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss"], @@ -436,8 +449,11 @@ def tearDown(self): "distillation,customized_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=[CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT], - expected_fake_quantize=69, + nncf_compression_config=[ + TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, + ], + expected_fake_quantize=64, expected_int8=35, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], From 24f40bf2c10e3b911b6d34f27435558e31a8f278 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 2 Feb 2024 14:32:39 +0100 Subject: [PATCH 13/14] CUSTOMIZED_QUANTIZATION_CONFIG is updated --- tests/openvino/test_training.py | 67 ++++++++++++++++----------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py index c306d2d530..aaffdcdc92 100644 --- a/tests/openvino/test_training.py +++ b/tests/openvino/test_training.py @@ -276,24 +276,21 @@ def tearDown(self): shutil.rmtree(self.output_dir) -TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG = { - "algorithm": "quantization", - "overflow_fix": "disable", - "initializer": { - "range": { - "num_init_samples": 16, - "type": "percentile", - "params": {"min_percentile": 0.01, "max_percentile": 99.99}, +CUSTOMIZED_QUANTIZATION_CONFIG = deepcopy(DEFAULT_QUANTIZATION_CONFIG) +CUSTOMIZED_QUANTIZATION_CONFIG.update( + { + "overflow_fix": "disable", + "initializer": { + "range": { + "num_init_samples": 16, + "type": "percentile", + "params": {"min_percentile": 0.01, "max_percentile": 99.99}, + }, + "batchnorm_adaptation": {"num_bn_adaptation_samples": 4}, }, - "batchnorm_adaptation": {"num_bn_adaptation_samples": 4}, - }, - "scope_overrides": {"activations": {"{re}.*matmul_0": {"mode": "asymmetric"}}}, - "ignored_scopes": [ - "BertForSequenceClassification/BertModel[bert]/__rsub___0", - "BertForSequenceClassification/BertModel[bert]/__mul___0", - "{re}BertLayer\\[[0-9]+\\]/BertAttention\\[attention\\]/BertSelfAttention\\[self\\]/__add___0", - ], -} + "scope_overrides": {"activations": {"{re}.*matmul_0": {"mode": "asymmetric"}}}, + } +) STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT = { "algorithm": "movement_sparsity", @@ -339,17 +336,17 @@ def tearDown(self): ), "customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=64, - expected_int8=35, + nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, + expected_fake_quantize=44, + expected_int8=32, compression_metrics=["compression_loss"], ), "distillation,customized_quantization": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", - nncf_compression_config=TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, - expected_fake_quantize=64, - expected_int8=35, + 
nncf_compression_config=CUSTOMIZED_QUANTIZATION_CONFIG, + expected_fake_quantize=44, + expected_int8=32, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), "structured_movement_sparsity": OVTrainerTestDescriptor( @@ -376,11 +373,11 @@ def tearDown(self): "customized_quantization,structured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[ - TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=64, - expected_int8=35, + expected_fake_quantize=44, + expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], ), @@ -397,11 +394,11 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[ - TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + CUSTOMIZED_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=64, - expected_int8=35, + expected_fake_quantize=44, + expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), @@ -429,11 +426,11 @@ def tearDown(self): "customized_quantization,unstructured_movement_sparsity": OVTrainerTestDescriptor( model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[ - TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=64, - expected_int8=35, + expected_fake_quantize=44, + expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss"], ), @@ -450,11 +447,11 @@ def tearDown(self): model_id="hf-internal-testing/tiny-random-bert", teacher_model_id="hf-internal-testing/tiny-random-bert", nncf_compression_config=[ - TINY_RANDOM_BERT_CUSTOMIZED_QUANTIZATION_CONFIG, + CUSTOMIZED_QUANTIZATION_CONFIG, UNSTRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT, ], - expected_fake_quantize=64, - expected_int8=35, + expected_fake_quantize=44, + expected_int8=32, expected_binary_masks=60, compression_metrics=["compression_loss", "distillation_loss", "task_loss"], ), From 0f45751f19c372f3242c5d64e3bc40bc61ae2c07 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Mon, 5 Feb 2024 11:06:41 +0100 Subject: [PATCH 14/14] Update README (#549) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6bcf786b58..85f50d24e7 100644 --- a/README.md +++ b/README.md @@ -78,10 +78,10 @@ It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2 optimum-cli export openvino --model gpt2 ov_model ``` -If you add `--int8`, the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision. +You can also apply 8-bit weight-only quantization when exporting your model : the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision. ```plain -optimum-cli export openvino --model gpt2 --int8 ov_model +optimum-cli export openvino --model gpt2 --weight-format int8 ov_model ``` To apply quantization on both weights and activations, you can find more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov).
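
As a companion to the README export commands above, a minimal usage sketch of loading the INT8-exported model for inference. The `ov_model` directory name simply follows the export example; whether the tokenizer files are saved alongside the exported model is an assumption — if they are not, load the tokenizer from the original `gpt2` checkpoint instead.

```python
from transformers import AutoTokenizer, pipeline

from optimum.intel import OVModelForCausalLM

# Load the OpenVINO model exported with `optimum-cli export openvino`
# (linear and embedding weights quantized to INT8, activations in floating point).
model = OVModelForCausalLM.from_pretrained("ov_model")
tokenizer = AutoTokenizer.from_pretrained("ov_model")  # assumption: tokenizer saved next to the model

# The exported model plugs into the regular transformers pipeline API.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("The weather today is", max_new_tokens=20))
```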