From f87f93fe554240eb69394e438178bea61b96973a Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 8 Sep 2023 18:44:04 +0400 Subject: [PATCH 1/8] Added hybrid quantization for seq2seq models --- optimum/intel/openvino/quantization.py | 48 ++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 99e22e72f5..df7f92930d 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -41,6 +41,7 @@ from .configuration import INT8_WEIGHT_COMPRESSION_CONFIG, OVConfig from .modeling_base import OVBaseModel from .modeling_decoder import OVBaseDecoderModel +from .modeling_seq2seq import OVModelForSeq2SeqLM from .utils import ( MAX_ONNX_OPSET, MIN_ONNX_QDQ_OPSET, @@ -187,6 +188,15 @@ def quantize( remove_unused_columns, **kwargs, ) + if isinstance(self.model, OVModelForSeq2SeqLM) and self.model.use_cache: + self._quantize_ovmodelforseq2seqlm( + calibration_dataset, + save_directory, + batch_size, + data_collator, + remove_unused_columns, + **kwargs, + ) elif isinstance(self.model, OVBaseModel): self._quantize_ovbasemodel( calibration_dataset, @@ -239,6 +249,44 @@ def _quantize_ovbasemodel( ) self.model.model = quantized_model self.model.save_pretrained(save_directory) + + def _quantize_ovmodelforseq2seqlm( + self, + calibration_dataset: Dataset, + save_directory: Union[str, Path], + batch_size: int = 1, + data_collator: Optional[DataCollator] = None, + remove_unused_columns: bool = True, + **kwargs, + ): + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + calibration_dataloader = self._get_calibration_dataloader( + calibration_dataset=calibration_dataset, + batch_size=batch_size, + remove_unused_columns=remove_unused_columns, + data_collator=data_collator, + ) + quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x) + + # Full quantization of encoder + quantized_model = nncf.quantize( + self.model.encoder.model, + quantization_dataset, + model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"), + fast_bias_correction=kwargs.get("fast_bias_correction", True), + **kwargs, + ) + self.model.encoder.model = quantized_model + + # Compress weights of decoders for safity + if self.model.decoder: + self.model.decoder.model = nncf.compress_weights(self.model.decoder.model) + if self.model.decoder_with_past: + self.model.decoder_with_past.model = nncf.compress_weights(self.model.decoder_with_past.model) + + self.model.save_pretrained(save_directory) def _quantize_ovcausallm( self, From c5b21b4254be7e9d725ae750af69c2ef3f9e968b Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Nov 2023 16:20:12 +0400 Subject: [PATCH 2/8] Added a test for Seq2Seq quantization --- optimum/intel/openvino/quantization.py | 22 ++++----- tests/openvino/test_exporters_cli.py | 2 +- tests/openvino/test_quantization.py | 66 +++++++++++++++++++++++--- tests/openvino/utils_tests.py | 2 +- 4 files changed, 72 insertions(+), 20 deletions(-) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 64cb41ab40..8424e7a391 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -255,7 +255,7 @@ def _quantize_ovbasemodel( ) self.model.model = quantized_model self.model.save_pretrained(save_directory) - + def _quantize_ovmodelforseq2seqlm( self, calibration_dataset: Dataset, @@ -275,23 +275,21 @@ def _quantize_ovmodelforseq2seqlm( data_collator=data_collator, ) quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x) - + # Full quantization of encoder - quantized_model = nncf.quantize( - self.model.encoder.model, + self.model.encoder_model = nncf.quantize( + self.model.encoder_model, quantization_dataset, - model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"), + model_type=nncf.ModelType.TRANSFORMER, fast_bias_correction=kwargs.get("fast_bias_correction", True), **kwargs, ) - self.model.encoder.model = quantized_model - + # Compress weights of decoders for safity - if self.model.decoder: - self.model.decoder.model = nncf.compress_weights(self.model.decoder.model) - if self.model.decoder_with_past: - self.model.decoder_with_past.model = nncf.compress_weights(self.model.decoder_with_past.model) - + self.model.decoder_model = nncf.compress_weights(self.model.decoder_model) + if self.model.use_cache: + self.model.decoder_with_past_model = nncf.compress_weights(self.model.decoder_with_past_model) + self.model.save_pretrained(save_directory) def _quantize_ovcausallm( diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 76a6ee629c..2042041c04 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -110,6 +110,6 @@ def test_exporters_cli_int8(self, task: str, model_type: str): expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8 = get_num_quantized_nodes(model.model) expected = expected_int8[i] self.assertEqual(expected, num_int8) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 83667f6a80..7eef345df3 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -56,6 +56,7 @@ _TASK_TO_DATASET = { "text-generation": ("wikitext", "wikitext-2-raw-v1", "text"), "text-classification": ("glue", "sst2", "sentence"), + "text2text-generation": ("wikitext", "wikitext-2-raw-v1", "text"), } @@ -134,7 +135,7 @@ def preprocess_function(examples, tokenizer): model = model_cls.from_pretrained(tmp_dir) - num_fake_quantize, num_int8 = get_num_quantized_nodes(model) + num_fake_quantize, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(expected_fake_quantize, num_fake_quantize) self.assertEqual(expected_int8, num_int8) @@ -143,6 +144,59 @@ def preprocess_function(examples, tokenizer): self.assertTrue("logits" in outputs) +class OVQuantizerSeq2SeqTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + ("hf-internal-testing/tiny-random-t5", 30, 32, 104, 84), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) + def test_ovmodel_hybrid_quantization( + self, + model_name, + expected_encoder_fq, + expected_encoder_int8, + expected_decoder_int8, + expected_decoder_with_past_int8, + ): + task = OVModelForSeq2SeqLM.export_feature + dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task] + + def preprocess_function(examples, tokenizer): + return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir = "t5_test" + transformers_model = OVModelForSeq2SeqLM.from_pretrained(model_name, export=True, use_cache=True) + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + quantizer = OVQuantizer.from_pretrained(transformers_model, task=task) + calibration_dataset = quantizer.get_calibration_dataset( + dataset_name, + dataset_config_name=dataset_config_name, + preprocess_function=partial(preprocess_function, tokenizer=tokenizer), + num_samples=10, + dataset_split="train", + ) + quantizer.quantize(save_directory=tmp_dir, calibration_dataset=calibration_dataset) + model = OVModelForSeq2SeqLM.from_pretrained(tmp_dir, use_cache=True) + + num_fake_quantize, num_int8 = get_num_quantized_nodes(model.encoder.model) + self.assertEqual(expected_encoder_fq, num_fake_quantize) + self.assertEqual(expected_encoder_int8, num_int8) + + _, num_int8 = get_num_quantized_nodes(model.decoder.model) + self.assertEqual(expected_decoder_int8, num_int8) + + if model.use_cache: + _, num_int8 = get_num_quantized_nodes(model.decoder_with_past.model) + self.assertEqual(expected_decoder_with_past_int8, num_int8) + + tokens = tokenizer("This is a sample input", return_tensors="pt") + outputs = model.generate(**tokens) + + class OVWeightCompressionTest(unittest.TestCase): # TODO : add models SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = ( @@ -178,7 +232,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i quantizer.quantize(save_directory=tmp_dir, weights_only=True) model = model_cls.from_pretrained(tmp_dir) - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(expected_pt_int8, num_int8) tokens = tokenizer("This is a sample input", return_tensors="pt") @@ -203,7 +257,7 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int quantizer.quantize(save_directory=tmp_dir, weights_only=True) model = model_cls.from_pretrained(tmp_dir) - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(expected_ov_int8, num_int8) tokens = tokenizer("This is a sample input", return_tensors="pt") @@ -224,7 +278,7 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type): expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(expected_ov_int8[i], num_int8) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) @@ -240,7 +294,7 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type): models = [model] for i, model in enumerate(models): - _, num_int8 = get_num_quantized_nodes(model) + _, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(0, num_int8) @@ -351,7 +405,7 @@ def compute_metrics(p): trainer.save_model() model = OVModelForSequenceClassification.from_pretrained(tmp_dir) - num_fake_quantize, num_int8 = get_num_quantized_nodes(model) + num_fake_quantize, num_int8 = get_num_quantized_nodes(model.model) self.assertEqual(expected_fake_quantize, num_fake_quantize) self.assertEqual(expected_int8, num_int8) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 72d4a0f810..b874d5497d 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -116,7 +116,7 @@ def get_num_quantized_nodes(ov_model): num_fake_quantize = 0 num_int8 = 0 - for elem in ov_model.model.get_ops(): + for elem in ov_model.get_ops(): if "FakeQuantize" in elem.name: num_fake_quantize += 1 for i in range(elem.get_output_size()): From 78153334b5f202560bd6d241eea6e368ad65ad0c Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Nov 2023 18:00:01 +0400 Subject: [PATCH 3/8] Fixed a couple of bugs --- optimum/intel/openvino/quantization.py | 2 +- tests/openvino/test_quantization.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 8424e7a391..34c7dd24b2 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -187,7 +187,7 @@ def quantize( weights_only, **kwargs, ) - if isinstance(self.model, OVModelForSeq2SeqLM) and self.model.use_cache: + if isinstance(self.model, OVModelForSeq2SeqLM): self._quantize_ovmodelforseq2seqlm( calibration_dataset, save_directory, diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 7eef345df3..de464d10a0 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -165,7 +165,6 @@ def preprocess_function(examples, tokenizer): return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True) with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = "t5_test" transformers_model = OVModelForSeq2SeqLM.from_pretrained(model_name, export=True, use_cache=True) tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: From 0cb7f19807b5d06755edac3f5f2a69b1951e489c Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Nov 2023 10:27:27 +0400 Subject: [PATCH 4/8] Fixed test issue --- optimum/intel/openvino/quantization.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 34c7dd24b2..b85ee7517b 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -41,8 +41,8 @@ from ..utils.constant import _TASK_ALIASES from .configuration import OVConfig from .modeling_base import OVBaseModel +from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM from .modeling_decoder import OVBaseDecoderModel -from .modeling_seq2seq import OVModelForSeq2SeqLM from .utils import ( MAX_ONNX_OPSET, MIN_ONNX_QDQ_OPSET, @@ -177,23 +177,23 @@ def quantize( "In case you only want to apply quantization on the weights, please set `weights_only=True`." ) - if isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache: - self._quantize_ovcausallm( + if isinstance(self.model, OVBaseModelForSeq2SeqLM): + self._quantize_ovmodelforseq2seqlm( calibration_dataset, save_directory, batch_size, data_collator, remove_unused_columns, - weights_only, **kwargs, ) - if isinstance(self.model, OVModelForSeq2SeqLM): - self._quantize_ovmodelforseq2seqlm( + if isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache: + self._quantize_ovcausallm( calibration_dataset, save_directory, batch_size, data_collator, remove_unused_columns, + weights_only, **kwargs, ) elif isinstance(self.model, OVBaseModel): From b8d98f47b2a8a32885d452b14f954116b51f2153 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 13 Nov 2023 14:10:21 +0400 Subject: [PATCH 5/8] Fixed issue --- optimum/intel/openvino/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index b85ee7517b..835f0d83e8 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -186,7 +186,7 @@ def quantize( remove_unused_columns, **kwargs, ) - if isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache: + elif isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache: self._quantize_ovcausallm( calibration_dataset, save_directory, From 35a30a011e295b7b6c7779e2ce679f91e00e4946 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 14 Nov 2023 09:57:51 +0400 Subject: [PATCH 6/8] Fixed test --- tests/openvino/test_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index de464d10a0..6884cba081 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -146,7 +146,7 @@ def preprocess_function(examples, tokenizer): class OVQuantizerSeq2SeqTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( - ("hf-internal-testing/tiny-random-t5", 30, 32, 104, 84), + ("hf-internal-testing/tiny-random-t5", 30, 32, 52, 84), ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) From 6df83299aed237968126ca61082aa1da4ec27c83 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 14 Nov 2023 11:00:15 +0400 Subject: [PATCH 7/8] Fixed test --- tests/openvino/test_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 6884cba081..00fa36b21b 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -146,7 +146,7 @@ def preprocess_function(examples, tokenizer): class OVQuantizerSeq2SeqTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( - ("hf-internal-testing/tiny-random-t5", 30, 32, 52, 84), + ("hf-internal-testing/tiny-random-t5", 30, 32, 52, 42), ) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) From 04f8344656d08f655f561be4c4a2284c7913468c Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 14 Nov 2023 11:06:13 +0400 Subject: [PATCH 8/8] Style --- tests/openvino/test_quantization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 00fa36b21b..5f25ca09b9 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -145,9 +145,7 @@ def preprocess_function(examples, tokenizer): class OVQuantizerSeq2SeqTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( - ("hf-internal-testing/tiny-random-t5", 30, 32, 52, 42), - ) + SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("hf-internal-testing/tiny-random-t5", 30, 32, 52, 42),) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) def test_ovmodel_hybrid_quantization(