Skip to content

Commit

Permalink
Fix collecting duplicate tensors in quantization calibration dataset (#577)
Browse files Browse the repository at this point in the history

* Added deepcopying of inputs collected by InferRequestWrapper. Added a test covering the fixed issue.

* Phrasing tweaks

* Add soundfile to test requirements

* Added librosa to test requirements

* Added copying to other data cache appends

* Remove the need for real test data

* Process __call__ call properly

* Addressed suggested changes
  • Loading branch information
nikita-savelyevv authored Mar 1, 2024
1 parent 2d8307e commit 652a15c
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
10 changes: 7 additions & 3 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import logging
import os
Expand Down Expand Up @@ -87,11 +88,14 @@ def __init__(self, request, data_cache=None):
self.data_cache = data_cache

def __call__(self, *args, **kwargs):
    """Record a snapshot of the ``inputs`` argument, then call the wrapped request.

    The positional and keyword arguments are bound against the signature of
    ``self.request`` so that ``inputs`` is found whether it was passed
    positionally or by keyword. Exactly one entry is appended to
    ``self.data_cache`` per call.

    Args:
        *args: Positional arguments forwarded unchanged to ``self.request``.
        **kwargs: Keyword arguments forwarded unchanged to ``self.request``.

    Returns:
        Whatever ``self.request(*args, **kwargs)`` returns.

    Raises:
        TypeError: If the arguments do not match the signature of ``self.request``.
        KeyError: If no ``inputs`` argument was supplied (defaults are not applied
            when binding).
    """
    # If __call__ is invoked then self.request must be an instance of CompiledModel
    signature = inspect.signature(self.request)
    bound_args = signature.bind(*args, **kwargs).arguments
    # Cache a deep copy only: storing the caller's object itself would let later
    # in-place updates of that object rewrite (and so duplicate) cached samples.
    self.data_cache.append(copy.deepcopy(bound_args["inputs"]))
    return self.request(*args, **kwargs)

def infer(self, inputs: Any = None, share_inputs: bool = False):
    """Record a snapshot of ``inputs``, then run ``self.request.infer``.

    Exactly one entry is appended to ``self.data_cache`` per call.

    Args:
        inputs: Inference inputs; forwarded unchanged to the wrapped request.
        share_inputs: Forwarded unchanged to the wrapped request.

    Returns:
        Whatever ``self.request.infer(inputs, share_inputs)`` returns.
    """
    # Cache a deep copy only: storing the caller's object itself would let later
    # in-place updates of that object rewrite (and so duplicate) cached samples.
    self.data_cache.append(copy.deepcopy(inputs))
    return self.request.infer(inputs, share_inputs)

def start_async(
Expand All @@ -102,7 +106,7 @@ def start_async(
*,
shared_memory: Any = None,
):
self.data_cache.append(inputs)
self.data_cache.append(copy.deepcopy(inputs))
self.request.infer(inputs, share_inputs, share_outputs=True)

def wait(self):
Expand Down
40 changes: 40 additions & 0 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

import tempfile
import unittest
from collections import defaultdict
from functools import partial

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
import openvino.runtime as ov
Expand All @@ -30,6 +32,7 @@
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
AutoProcessor,
TrainingArguments,
default_data_collator,
)
Expand All @@ -45,6 +48,7 @@
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
OVModelForSpeechSeq2Seq,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
OVQuantizer,
Expand All @@ -54,6 +58,7 @@


from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
from optimum.intel.openvino.quantization import InferRequestWrapper
from optimum.intel.utils.import_utils import is_openvino_version
from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8

Expand Down Expand Up @@ -589,3 +594,38 @@ def compute_metrics(p):
tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
self.assertTrue("logits" in outputs)


class InferRequestWrapperTest(unittest.TestCase):
    """Checks that inputs collected through ``InferRequestWrapper`` are not all identical."""

    # Model ids fed to ``parameterized.expand`` below.
    MODEL_ID = ("openai/whisper-tiny.en",)

    @staticmethod
    def _generate_random_audio_data(processor):
        """Build a sine wave with a random frequency factor and return its processed features.

        Args:
            processor: Callable applied to the waveform; its result must expose
                ``input_features``.

        Returns:
            The ``input_features`` attribute of the processor output.
        """
        time_axis = np.linspace(0, 1.0, 1000, endpoint=False)
        # The random term makes every generated waveform differ from the previous one.
        waveform = 0.5 * np.sin((2 + np.random.random()) * np.pi * time_axis)
        processed = processor(waveform, sampling_rate=16000, return_tensors="pt")
        return processed.input_features

    @parameterized.expand(MODEL_ID)
    def test_calibration_data_uniqueness(self, model_id):
        """Run generation twice on different audio and verify the collected inputs vary."""
        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
        processor = AutoProcessor.from_pretrained(model_id)

        # Swap the decoder request for a wrapper that records every input it receives.
        collected_inputs = []
        ov_model.decoder_with_past.request = InferRequestWrapper(
            ov_model.decoder_with_past.request,
            collected_inputs,
        )
        for _ in range(2):
            ov_model.generate(self._generate_random_audio_data(processor))

        # Group a content hash of every recorded array by the input name it was stored under.
        hashes_by_input = defaultdict(list)
        for sample in collected_inputs:
            for input_name, value in sample.items():
                array = value.numpy() if isinstance(value, torch.Tensor) else value
                hashes_by_input[input_name].append(hash(array.copy().tobytes()))

        for hashes in hashes_by_input.values():
            # Two different audio samples were fed in, so for each input at least one
            # recorded hash must differ from the first one.
            self.assertTrue(any(hashes[0] != digest for digest in hashes))

0 comments on commit 652a15c

Please sign in to comment.