Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix collecting duplicate tensors in quantization calibration dataset #577

10 changes: 7 additions & 3 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import logging
import os
Expand Down Expand Up @@ -87,11 +88,14 @@ def __init__(self, request, data_cache=None):
self.data_cache = data_cache

def __call__(self, *args, **kwargs):
self.data_cache.append(*args)
# If __call__ is invoked then self.request must be an instance of CompiledModel
signature = inspect.signature(self.request)
bound_args = signature.bind(*args, **kwargs).arguments
self.data_cache.append(copy.deepcopy(bound_args["inputs"]))
return self.request(*args, **kwargs)

def infer(self, inputs: Any = None, share_inputs: bool = False):
self.data_cache.append(inputs)
self.data_cache.append(copy.deepcopy(inputs))
return self.request.infer(inputs, share_inputs)

def start_async(
Expand All @@ -102,7 +106,7 @@ def start_async(
*,
shared_memory: Any = None,
):
self.data_cache.append(inputs)
self.data_cache.append(copy.deepcopy(inputs))
self.request.infer(inputs, share_inputs, share_outputs=True)

def wait(self):
Expand Down
40 changes: 40 additions & 0 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import tempfile
import unittest
from collections import defaultdict
from functools import partial

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
import openvino.runtime as ov
Expand All @@ -30,6 +32,7 @@
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
AutoProcessor,
TrainingArguments,
default_data_collator,
)
Expand All @@ -45,6 +48,7 @@
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
OVModelForSpeechSeq2Seq,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
OVQuantizer,
Expand All @@ -54,6 +58,7 @@


from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
from optimum.intel.openvino.quantization import InferRequestWrapper
from optimum.intel.utils.import_utils import is_openvino_version
from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8

Expand Down Expand Up @@ -589,3 +594,38 @@ def compute_metrics(p):
tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
self.assertTrue("logits" in outputs)


class InferRequestWrapperTest(unittest.TestCase):
MODEL_ID = ("openai/whisper-tiny.en",)

@staticmethod
def _generate_random_audio_data(processor):
t = np.linspace(0, 1.0, int(1000), endpoint=False)
audio_data = 0.5 * np.sin((2 + np.random.random()) * np.pi * t)
input_features = processor(
audio_data,
sampling_rate=16000,
return_tensors="pt",
).input_features
return input_features

@parameterized.expand(MODEL_ID)
def test_calibration_data_uniqueness(self, model_id):
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
processor = AutoProcessor.from_pretrained(model_id)

calibration_data = []
ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, calibration_data)
for _ in range(2):
input_features = self._generate_random_audio_data(processor)
ov_model.generate(input_features)

data_hashes_per_key = defaultdict(list)
for inputs_dict in calibration_data:
for k, v in inputs_dict.items():
x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()
data_hashes_per_key[k].append(hash(x.tobytes()))
for k, data_hashes in data_hashes_per_key.items():
# All hashes can not be equal because calibration dataset contains at least 2 different samples
self.assertTrue(any(data_hashes[0] != it for it in data_hashes))
Loading