Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jan 29, 2024
1 parent 7818f5a commit 9429c01
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 7 deletions.
3 changes: 1 addition & 2 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals

def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
model.config.return_dict = "text-generation" not in task

model.config.return_dict = task not in {"text-generation", "audio-classification"}
# Run one forward pass to check that the prepared model_inputs are accepted by the model.
model(**model_inputs)

Expand Down
52 changes: 47 additions & 5 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def __init__(
self.model.to(self._device)
self.model_save_dir = model_save_dir

self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.base_model_prefix, AutoConfig)
Expand Down Expand Up @@ -172,8 +175,22 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
torch.jit.save(self.model, output_path)

def forward(self, *args, **kwargs):
outputs = self.model(*args, **kwargs)
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor = None,
    **kwargs,
):
    """Run the wrapped model on tokenized text inputs.

    Only the inputs that the wrapped model's graph declares (collected in
    ``self.input_names``) are forwarded, so a tokenizer may return
    ``token_type_ids`` even when the model was exported without them.

    Args:
        input_ids: Token ids passed to the model as ``input_ids``.
        attention_mask: Mask passed to the model as ``attention_mask``.
        token_type_ids: Optional segment ids. Forwarded only when
            ``"token_type_ids"`` is in ``self.input_names``; when the model
            expects them and none are given, zeros shaped like ``input_ids``
            are used instead of ``None``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A ``ModelOutput`` built from the model's dict output, or one holding
        the first element of the output under ``logits`` otherwise.
    """
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }

    if "token_type_ids" in self.input_names:
        if token_type_ids is None:
            # The graph declares this input, so it cannot be omitted or passed
            # as None. All-zero segment ids (single-segment input) are the
            # usual default — NOTE(review): confirm this matches the export.
            token_type_ids = torch.zeros_like(input_ids)
        inputs["token_type_ids"] = token_type_ids

    outputs = self.model(**inputs)
    return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

def eval(self):
Expand Down Expand Up @@ -212,11 +229,39 @@ class IPEXModelForImageClassification(IPEXModel):
auto_model_class = AutoModelForImageClassification
export_feature = "image-classification"

def forward(
    self,
    pixel_values: torch.Tensor,
    **kwargs,
):
    """Run the wrapped model on image inputs.

    Args:
        pixel_values: Tensor passed to the model as ``pixel_values``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A ``ModelOutput`` built from the model's dict output, or one holding
        the first element of the output under ``logits`` otherwise.
    """
    raw_outputs = self.model(pixel_values=pixel_values)

    # Dict outputs already carry named fields; anything else is treated as a
    # sequence whose first item is the logits.
    if isinstance(raw_outputs, dict):
        return ModelOutput(**raw_outputs)
    return ModelOutput(logits=raw_outputs[0])


class IPEXModelForAudioClassification(IPEXModel):
    """``IPEXModel`` variant for the ``audio-classification`` export feature."""

    auto_model_class = AutoModelForAudioClassification
    export_feature = "audio-classification"

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Run the wrapped model on audio inputs.

        Args:
            input_values: Tensor passed to the model as ``input_values``.
            attention_mask: Optional mask. Forwarded only when
                ``"attention_mask"`` is in ``self.input_names``; when the
                model expects one and none is given, an all-ones mask shaped
                like ``input_values`` is used instead of ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A ``ModelOutput`` built from the model's dict output, or one
            holding the first element of the output under ``logits``
            otherwise.
        """
        inputs = {
            "input_values": input_values,
        }

        if "attention_mask" in self.input_names:
            if attention_mask is None:
                # The graph declares this input, so None cannot be forwarded.
                # An all-ones mask marks every sample as valid (no padding).
                # NOTE(review): assumes the mask has the same shape as
                # input_values and an integer dtype — confirm against export.
                attention_mask = torch.ones_like(input_values, dtype=torch.long)
            inputs["attention_mask"] = attention_mask

        outputs = self.model(**inputs)
        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
auto_model_class = AutoModelForQuestionAnswering
Expand Down Expand Up @@ -245,9 +290,6 @@ def __init__(

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
Expand Down
97 changes: 97 additions & 0 deletions tests/ipex/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
import time
import unittest

import numpy as np
import requests
import torch
from parameterized import parameterized
from PIL import Image
from transformers import (
AutoFeatureExtractor,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoTokenizer,
Expand All @@ -30,7 +34,9 @@
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.intel import (
IPEXModel,
IPEXModelForAudioClassification,
IPEXModelForCausalLM,
IPEXModelForImageClassification,
IPEXModelForMaskedLM,
IPEXModelForQuestionAnswering,
IPEXModelForSequenceClassification,
Expand All @@ -42,13 +48,15 @@

MODEL_NAMES = {
"albert": "hf-internal-testing/tiny-random-albert",
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
"bert": "hf-internal-testing/tiny-random-bert",
"bart": "hf-internal-testing/tiny-random-bart",
"blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel",
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
"codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM",
"convnext": "hf-internal-testing/tiny-random-convnext",
"distilbert": "hf-internal-testing/tiny-random-distilbert",
"electra": "hf-internal-testing/tiny-random-electra",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
Expand All @@ -57,17 +65,25 @@
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "stas/mt5-tiny-random",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-roberta",
"roformer": "hf-internal-testing/tiny-random-roformer",
"squeezebert": "hf-internal-testing/tiny-random-squeezebert",
"t5": "hf-internal-testing/tiny-random-t5",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"vit": "hf-internal-testing/tiny-random-vit",
"wav2vec2": "anton-l/wav2vec2-random-tiny-classifier",
"xlm": "hf-internal-testing/tiny-random-xlm",
}

Expand Down Expand Up @@ -266,3 +282,84 @@ def test_compare_with_and_without_past_key_values(self):
f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
)


class IPEXModelForAudioClassificationTest(unittest.TestCase):
    """Checks IPEXModelForAudioClassification against transformers and in a pipeline."""

    IPEX_MODEL_CLASS = IPEXModelForAudioClassification
    SUPPORTED_ARCHITECTURES = (
        "unispeech",
        "wav2vec2",
    )

    def _generate_random_audio_data(self):
        """Return a 5 second, 220 Hz sine wave with amplitude 0.5, sampled at 22050 Hz."""
        # Despite the name, the returned signal is deterministic; seeding only
        # resets NumPy's global random state.
        np.random.seed(10)
        duration = 5.0
        sample_rate = 22050
        frequency = 220
        timestamps = np.linspace(0, duration, int(duration * sample_rate), endpoint=False)
        return 0.5 * np.sin(2 * np.pi * frequency * timestamps)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_compare_to_transformers(self, model_arch):
        """Exported model logits must match the transformers model within 1e-3."""
        checkpoint = MODEL_NAMES[model_arch]
        ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(checkpoint, export=True)
        self.assertIsInstance(ipex_model.config, PretrainedConfig)

        reference_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(checkpoint)
        feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
        waveform = self._generate_random_audio_data()
        inputs = feature_extractor(waveform, return_tensors="pt")

        with torch.no_grad():
            reference_outputs = reference_model(**inputs)
        ipex_outputs = ipex_model(**inputs)

        # Compare tensor outputs
        logits_match = torch.allclose(ipex_outputs.logits, reference_outputs.logits, atol=1e-3)
        self.assertTrue(logits_match)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_pipeline(self, model_arch):
        """The exported model must run in an audio-classification pipeline."""
        checkpoint = MODEL_NAMES[model_arch]
        ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(checkpoint, export=True)
        feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
        classifier = pipeline("audio-classification", model=ipex_model, feature_extractor=feature_extractor)

        predictions = classifier([np.random.random(16000)])

        self.assertEqual(classifier.device, ipex_model.device)
        # Every score returned for the single input must be strictly positive.
        self.assertTrue(all(item["score"] > 0.0 for item in predictions[0]))


class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase):
    """Checks IPEXModelForImageClassification against transformers and in a pipeline."""

    SUPPORTED_ARCHITECTURES = (
        "beit",
        # "levit",  # NOTE(review): disabled in the original; reason not recorded here.
        "mobilenet_v1",
        "mobilenet_v2",
        "mobilevit",
        "resnet",
        "vit",
    )
    IPEX_MODEL_CLASS = IPEXModelForImageClassification

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_compare_to_transformers(self, model_arch):
        """Exported model logits must match the transformers model within 1e-4."""
        model_id = MODEL_NAMES[model_arch]
        set_seed(SEED)
        ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
        self.assertIsInstance(ipex_model.config, PretrainedConfig)
        transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id)
        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        # Bound the download so an unresponsive host fails the test instead of
        # hanging it, and surface HTTP errors before handing the body to PIL.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw)
        inputs = preprocessor(images=image, return_tensors="pt")
        with torch.no_grad():
            transformers_outputs = transformers_model(**inputs)
        outputs = ipex_model(**inputs)
        self.assertIn("logits", outputs)
        # Compare tensor outputs
        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_pipeline(self, model_arch):
        """The exported model must run in an image-classification pipeline."""
        model_id = MODEL_NAMES[model_arch]
        model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
        preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
        pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor)
        outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg")
        self.assertEqual(pipe.device, model.device)
        self.assertGreaterEqual(outputs[0]["score"], 0.0)
        self.assertIsInstance(outputs[0]["label"], str)

0 comments on commit 9429c01

Please sign in to comment.