Fix nncf quantization for decoder models (#727)
* Fix nncf quantization for decoder models

* add test

* update op quant op

* remove deprecated warning

* update expected quantized

* enable stateful

* style
echarlaix authored May 24, 2024
1 parent 1319d7b commit e22b2fd
Showing 3 changed files with 17 additions and 15 deletions.
5 changes: 3 additions & 2 deletions optimum/intel/openvino/modeling_decoder.py
@@ -42,7 +42,7 @@
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
-from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
+from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE
 
 
 if TYPE_CHECKING:
@@ -409,6 +409,7 @@ def prepare_inputs(
             elif self.use_cache:
                 for input_name in self.key_value_input_names:
                     model_inputs = self.model.input(input_name)
+                    dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
                     shape = model_inputs.get_partial_shape()
                     if self.config.model_type == "chatglm":
                         shape[0] = 0
@@ -419,7 +420,7 @@
                         shape[2] = 0
                     else:
                         shape[1] = 0
-                    inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+                    inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype)
         else:
             # past_key_values are not used explicitly, instead they are handled inside the model
             if past_key_values is None:
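For readers following the change in `prepare_inputs`: the placeholder past_key_values for the first decoding step are now allocated as empty numpy arrays, with the dtype looked up from the input port via `OV_TO_NP_TYPE`, instead of being passed as OpenVINO `Tensor` objects, presumably so that everything collected during calibration is a plain numpy array that NNCF can consume. A minimal sketch of that pattern follows; the trimmed `OV_TO_NP_TYPE` dict and the `empty_past_key_value` helper are illustrative stand-ins (the real mapping lives in `optimum.intel.openvino.utils`), and it assumes `model_input` is an input port of a compiled OpenVINO model with dynamic cache dimensions.

```python
import numpy as np

# Illustrative subset of the element-type mapping; the real OV_TO_NP_TYPE in
# optimum.intel.openvino.utils covers more OpenVINO element types.
OV_TO_NP_TYPE = {"f32": np.float32, "f16": np.float16, "i64": np.int64, "i32": np.int32}


def empty_past_key_value(model_input):
    """Hypothetical helper mirroring the new logic for one past_key_values port."""
    # Map the port's OpenVINO element type (e.g. "f32") to the matching numpy dtype.
    dtype = OV_TO_NP_TYPE[model_input.get_element_type().get_type_name()]
    shape = model_input.get_partial_shape()
    # Dynamic dimensions (e.g. the past sequence length) are materialized as 0,
    # so the first forward pass receives an empty cache rather than an ov.Tensor.
    dims = [dim.get_length() if dim.is_static else 0 for dim in shape]
    return np.empty(dims, dtype=dtype)
```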
9 changes: 4 additions & 5 deletions optimum/intel/openvino/quantization.py
@@ -347,7 +347,6 @@ def _quantize_ovbasemodel(
                 remove_unused_columns=remove_unused_columns,
                 data_collator=data_collator,
             )
-
             if self.model.export_feature == "text-generation" and self.model.use_cache:
                 calibration_dataset = self._prepare_text_generation_dataset(
                     quantization_config, calibration_dataloader
@@ -430,6 +429,7 @@ def _quantize_ovbasemodel(
                 ),
                 **kwargs,
             )
+
             self.model.model = quantized_model
             if save_directory is not None:
                 self.model.save_pretrained(save_directory)
@@ -696,24 +696,23 @@ def _prepare_builtin_dataset(self, quantization_config: OVWeightQuantizationConf
     def _prepare_text_generation_dataset(
         self, quantization_config: OVQuantizationConfig, calibration_dataloader: OVDataLoader
     ) -> nncf.Dataset:
-        # TODO: this function is not covered by tests, remove if not relevant anymore or cover by tests otherwise
-
         # Prefetch past_key_values
         self.model.update_pkv_precision(True)
         self.model.compile()
         collected_inputs = []
 
         num_samples = quantization_config.num_samples or 200
+
-        self.model.request = InferRequestWrapper(self.model.model.request, collected_inputs)
+        self.model.request = InferRequestWrapper(self.model.request, collected_inputs)
         try:
             for data in calibration_dataloader:
                 self.model.generate(**data, max_new_tokens=1)
                 if len(collected_inputs) >= num_samples:
                     break
         finally:
-            self.model.model.request = self.model.model.request.request
+            self.model.request = self.model.request.request
         calibration_dataset = nncf.Dataset(collected_inputs)
 
         return calibration_dataset
 
     def _prepare_unet_dataset(
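The functional fix above is that the calibration wrapper is now installed on and restored from the same attribute, `self.model.request`: previously the code wrapped `self.model.model.request` but assigned the wrapper to `self.model.request`, and then tried to unwrap `self.model.model.request`, which was never the wrapper. A hedged sketch of the overall collection pattern, assuming `ov_model` is an already compiled `OVModelForCausalLM` and `dataloader` yields tokenized batches; the helper name is illustrative, and the `InferRequestWrapper` import path simply mirrors where the class is used in `quantization.py`, not a guaranteed public API.

```python
import nncf
from optimum.intel.openvino.quantization import InferRequestWrapper


def collect_decoder_calibration_data(ov_model, dataloader, num_samples=200):
    """Record decoder inputs seen during generation and wrap them for NNCF."""
    collected_inputs = []
    # Install the wrapper on ov_model.request so every infer() call made while
    # generating records its inputs; the same attribute is restored afterwards.
    ov_model.request = InferRequestWrapper(ov_model.request, collected_inputs)
    try:
        for batch in dataloader:
            ov_model.generate(**batch, max_new_tokens=1)
            if len(collected_inputs) >= num_samples:
                break
    finally:
        # Unwrap: the original InferRequest is kept on the wrapper's .request.
        ov_model.request = ov_model.request.request
    return nncf.Dataset(collected_inputs)
```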
18 changes: 10 additions & 8 deletions tests/openvino/test_quantization.py
@@ -73,12 +73,16 @@
 
 
 class OVQuantizerTest(unittest.TestCase):
-    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
+    SUPPORTED_ARCHITECTURES_TORCH_MODEL = (
         (OVModelForSequenceClassification, "bert", 32, 35),
-        # (OVModelForCausalLM, "gpt2", 41, 23),
+        (OVModelForCausalLM, "gpt2", 41, 3),
     )
+    SUPPORTED_ARCHITECTURES_OV_MODEL = (
+        (OVModelForSequenceClassification, "bert", 32, 35),
+        (OVModelForCausalLM, "gpt2", 31, 22),
+    )
 
-    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL)
     def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
@@ -123,23 +127,21 @@ def preprocess_function(examples, tokenizer):
             loaded_config = OVConfig.from_pretrained(tmp_dir)
             self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
-    @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
+    @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL)
     def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
         dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
-        if "gpt2" in model_id:
-            expected_int8 -= 1
 
         def preprocess_function(examples, tokenizer):
             return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
-            transformers_model = model_cls.from_pretrained(model_id, export=True)
+            ov_model = model_cls.from_pretrained(model_id, export=True)
             tokenizer = AutoTokenizer.from_pretrained(model_id)
             if tokenizer.pad_token is None:
                 tokenizer.pad_token = tokenizer.eos_token
-            quantizer = OVQuantizer.from_pretrained(transformers_model, task=task)
+            quantizer = OVQuantizer.from_pretrained(ov_model, task=task)
 
             calibration_dataset = quantizer.get_calibration_dataset(
                 dataset_name,
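With this split, the torch-export path (`SUPPORTED_ARCHITECTURES_TORCH_MODEL`) and the OpenVINO-model path (`SUPPORTED_ARCHITECTURES_OV_MODEL`) each carry their own expected fake-quantize and int8 op counts, and `test_ovmodel_static_quantization` now feeds an exported `OVModelForCausalLM` straight into `OVQuantizer`. A rough end-to-end sketch of that flow outside the test harness; the model id, dataset, sample count, and output directory below are illustrative, not the values pinned in CI.

```python
from functools import partial

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM, OVQuantizer

model_id = "gpt2"  # illustrative; the CI test uses a tiny test checkpoint
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def preprocess_function(examples, tokenizer):
    # Tokenize the raw text column into fixed-length calibration samples.
    return tokenizer(examples["text"], padding="max_length", max_length=128, truncation=True)


quantizer = OVQuantizer.from_pretrained(ov_model, task="text-generation")
calibration_dataset = quantizer.get_calibration_dataset(
    "wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=10,
    dataset_split="train",
)
# Static int8 quantization of the decoder, saved as an OpenVINO IR.
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="gpt2-int8-ov")
```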
