Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added hybrid quantization for seq2seq models #427

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ..utils.constant import _TASK_ALIASES
from .configuration import OVConfig
from .modeling_base import OVBaseModel
from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
from .modeling_decoder import OVBaseDecoderModel
from .utils import (
MAX_ONNX_OPSET,
Expand Down Expand Up @@ -176,7 +177,16 @@ def quantize(
"In case you only want to apply quantization on the weights, please set `weights_only=True`."
)

if isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache:
if isinstance(self.model, OVBaseModelForSeq2SeqLM):
self._quantize_ovmodelforseq2seqlm(
calibration_dataset,
save_directory,
batch_size,
data_collator,
remove_unused_columns,
**kwargs,
)
elif isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache:
self._quantize_ovcausallm(
calibration_dataset,
save_directory,
Expand Down Expand Up @@ -246,6 +256,42 @@ def _quantize_ovbasemodel(
self.model.model = quantized_model
self.model.save_pretrained(save_directory)

def _quantize_ovmodelforseq2seqlm(
self,
calibration_dataset: Dataset,
save_directory: Union[str, Path],
batch_size: int = 1,
data_collator: Optional[DataCollator] = None,
remove_unused_columns: bool = True,
**kwargs,
):
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)

calibration_dataloader = self._get_calibration_dataloader(
calibration_dataset=calibration_dataset,
batch_size=batch_size,
remove_unused_columns=remove_unused_columns,
data_collator=data_collator,
)
quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)

# Full quantization of encoder
self.model.encoder_model = nncf.quantize(
self.model.encoder_model,
quantization_dataset,
model_type=nncf.ModelType.TRANSFORMER,
fast_bias_correction=kwargs.get("fast_bias_correction", True),
**kwargs,
)

# Compress weights of decoders for safity
self.model.decoder_model = nncf.compress_weights(self.model.decoder_model)
if self.model.use_cache:
self.model.decoder_with_past_model = nncf.compress_weights(self.model.decoder_with_past_model)
Comment on lines +288 to +291
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would there be a possibility to also quantize the activations for the decoder components?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is possible but it may hurt accuracy.


self.model.save_pretrained(save_directory)

def _quantize_ovcausallm(
self,
calibration_dataset: Dataset,
Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ def test_exporters_cli_int8(self, task: str, model_type: str):

expected_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type]
for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
_, num_int8 = get_num_quantized_nodes(model.model)
expected = expected_int8[i]
self.assertEqual(expected, num_int8)
63 changes: 57 additions & 6 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
_TASK_TO_DATASET = {
"text-generation": ("wikitext", "wikitext-2-raw-v1", "text"),
"text-classification": ("glue", "sst2", "sentence"),
"text2text-generation": ("wikitext", "wikitext-2-raw-v1", "text"),
}


Expand Down Expand Up @@ -134,7 +135,7 @@ def preprocess_function(examples, tokenizer):

model = model_cls.from_pretrained(tmp_dir)

num_fake_quantize, num_int8 = get_num_quantized_nodes(model)
num_fake_quantize, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(expected_fake_quantize, num_fake_quantize)
self.assertEqual(expected_int8, num_int8)

Expand All @@ -143,6 +144,56 @@ def preprocess_function(examples, tokenizer):
self.assertTrue("logits" in outputs)


class OVQuantizerSeq2SeqTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("hf-internal-testing/tiny-random-t5", 30, 32, 52, 42),)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
def test_ovmodel_hybrid_quantization(
self,
model_name,
expected_encoder_fq,
expected_encoder_int8,
expected_decoder_int8,
expected_decoder_with_past_int8,
):
task = OVModelForSeq2SeqLM.export_feature
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]

def preprocess_function(examples, tokenizer):
return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)

with tempfile.TemporaryDirectory() as tmp_dir:
transformers_model = OVModelForSeq2SeqLM.from_pretrained(model_name, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
calibration_dataset = quantizer.get_calibration_dataset(
dataset_name,
dataset_config_name=dataset_config_name,
preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
num_samples=10,
dataset_split="train",
)
quantizer.quantize(save_directory=tmp_dir, calibration_dataset=calibration_dataset)
model = OVModelForSeq2SeqLM.from_pretrained(tmp_dir, use_cache=True)

num_fake_quantize, num_int8 = get_num_quantized_nodes(model.encoder.model)
self.assertEqual(expected_encoder_fq, num_fake_quantize)
self.assertEqual(expected_encoder_int8, num_int8)

_, num_int8 = get_num_quantized_nodes(model.decoder.model)
self.assertEqual(expected_decoder_int8, num_int8)

if model.use_cache:
_, num_int8 = get_num_quantized_nodes(model.decoder_with_past.model)
self.assertEqual(expected_decoder_with_past_int8, num_int8)

tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model.generate(**tokens)


class OVWeightCompressionTest(unittest.TestCase):
# TODO : add models
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_COMPRESSED_MATMULS = (
Expand Down Expand Up @@ -178,7 +229,7 @@ def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_i
quantizer.quantize(save_directory=tmp_dir, weights_only=True)
model = model_cls.from_pretrained(tmp_dir)

_, num_int8 = get_num_quantized_nodes(model)
_, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(expected_pt_int8, num_int8)

tokens = tokenizer("This is a sample input", return_tensors="pt")
Expand All @@ -203,7 +254,7 @@ def test_ovmodel_weight_compression(self, model_cls, model_name, expected_pt_int
quantizer.quantize(save_directory=tmp_dir, weights_only=True)
model = model_cls.from_pretrained(tmp_dir)

_, num_int8 = get_num_quantized_nodes(model)
_, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(expected_ov_int8, num_int8)

tokens = tokenizer("This is a sample input", return_tensors="pt")
Expand All @@ -224,7 +275,7 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):

expected_ov_int8 = _ARCHITECTURES_TO_EXPECTED_INT8[model_type]
for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
_, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(expected_ov_int8[i], num_int8)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
Expand All @@ -240,7 +291,7 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
models = [model]

for i, model in enumerate(models):
_, num_int8 = get_num_quantized_nodes(model)
_, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(0, num_int8)


Expand Down Expand Up @@ -351,7 +402,7 @@ def compute_metrics(p):
trainer.save_model()

model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
num_fake_quantize, num_int8 = get_num_quantized_nodes(model)
num_fake_quantize, num_int8 = get_num_quantized_nodes(model.model)
self.assertEqual(expected_fake_quantize, num_fake_quantize)
self.assertEqual(expected_int8, num_int8)

Expand Down
2 changes: 1 addition & 1 deletion tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
def get_num_quantized_nodes(ov_model):
num_fake_quantize = 0
num_int8 = 0
for elem in ov_model.model.get_ops():
for elem in ov_model.get_ops():
if "FakeQuantize" in elem.name:
num_fake_quantize += 1
for i in range(elem.get_output_size()):
Expand Down
Loading