
add tests and readme
jiqing-feng committed Jan 8, 2024
1 parent 1b89624 commit 2bf2122
Showing 4 changed files with 99 additions and 0 deletions.
40 changes: 40 additions & 0 deletions README.md
@@ -41,6 +41,46 @@ where `extras` can be one or more of `neural-compressor`, `openvino`, `nncf`.

# Quick tour

## IPEX
### pipeline
Hugging Face pipelines provide a simple yet powerful abstraction to quickly set up inference. If you already have a pipeline from transformers, you can unlock the performance benefits of Optimum-Intel by just changing one line.
```diff
import torch
- from transformers.pipelines import pipeline
+ from optimum.intel.pipelines import pipeline

pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16)
pipe("Describe a real-world application of AI in sustainable energy.")
```
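The object returned by `optimum.intel.pipelines.pipeline` is meant to be used exactly like its transformers counterpart. As a minimal sketch (the generation keyword arguments below are illustrative, not required):
```python
import torch

from optimum.intel.pipelines import pipeline

# Build the pipeline once; generation kwargs are passed at call time,
# just as with a regular transformers pipeline.
pipe = pipeline("text-generation", "gpt2", torch_dtype=torch.bfloat16)
outputs = pipe("Describe a real-world application of AI in sustainable energy.", max_new_tokens=32)
print(outputs[0]["generated_text"])
```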

### generate
If you want more control over advanced features such as quantization and token selection strategies, we recommend using the `generate()` API. As with pipelines, switching over from existing transformers code requires only minimal changes.
```diff
import torch
from transformers import AutoTokenizer, AutoConfig
- from transformers import AutoModelForCausalLM
+ from optimum.intel.generation.modeling import TSModelForCausalLM

name = 'gpt2'
config = AutoConfig.from_pretrained(name, trust_remote_code=True)

model = TSModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    export=True,
)

tokenizer = AutoTokenizer.from_pretrained(name)
input_sentence = ["Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?"]
model_inputs = tokenizer(input_sentence, return_tensors="pt")
generation_kwargs = dict(max_new_tokens=32, do_sample=False, num_beams=4, num_beam_groups=1, no_repeat_ngram_size=2, use_cache=True)

generated_ids = model.generate(**model_inputs, **generation_kwargs)
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(output)
```
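Assuming `TSModelForCausalLM` follows the usual Optimum modeling API with a `save_pretrained`/`from_pretrained` pair, the exported TorchScript model can be saved once and reloaded later without re-exporting; a rough sketch continuing the example above (the local path is hypothetical):
```python
# Persist the exported model and its tokenizer to a local directory...
save_dir = "./gpt2-torchscript"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# ...and reload them later without passing export=True again.
reloaded_model = TSModelForCausalLM.from_pretrained(save_dir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(save_dir)
```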

## Neural Compressor

Dynamic quantization can be used through the Optimum command-line interface:
1 change: 1 addition & 0 deletions optimum/intel/generation/modeling.py
@@ -428,5 +428,6 @@ def _from_transformers(
    force_download=force_download,
    cache_dir=cache_dir,
    local_files_only=local_files_only,
+   model_dtype=torch_dtype,
    **kwargs,
)
5 changes: 5 additions & 0 deletions optimum/intel/pipelines/__init__.py
@@ -31,6 +31,7 @@
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    cached_file,
    extract_commit_hash,
+   is_ipex_available,
    is_offline_mode,
    is_torch_available,
    logging,
@@ -39,6 +40,10 @@
from ..generation.modeling import TSModelForCausalLM


+ if is_ipex_available():
+     import intel_extension_for_pytorch


if is_torch_available():
    import torch
    from transformers.models.auto.modeling_auto import AutoModelForCausalLM
53 changes: 53 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -0,0 +1,53 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized
from transformers.pipelines import pipeline as transformers_pipeline

from optimum.intel.generation.modeling import TSModelForCausalLM
from optimum.intel.pipelines import pipeline as ipex_pipeline


MODEL_NAMES = {
    "bert": "hf-internal-testing/tiny-random-bert",
    "distilbert": "hf-internal-testing/tiny-random-distilbert",
    "roberta": "hf-internal-testing/tiny-random-roberta",
    "bloom": "hf-internal-testing/tiny-random-bloom",
    "gptj": "hf-internal-testing/tiny-random-gptj",
    "gpt2": "hf-internal-testing/tiny-random-gpt2",
    "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
    "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
}


class PipelinesIntegrationTest(unittest.TestCase):
    TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("bloom", "gptj", "gpt2", "gpt_neo")

    @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES)
    def test_text_generation_pipeline_inference(self, model_arch):
        model_id = MODEL_NAMES[model_arch]
        inputs = "DeepSpeed is a machine learning framework for deep neural networks and deep reinforcement learning. It is written in C++ and is available for Linux, Mac OS X,"
        transformers_text_generator = transformers_pipeline("text-generation", model_id)
        ipex_text_generator = ipex_pipeline("text-generation", model_id)
        with torch.inference_mode():
            transformers_output = transformers_text_generator(inputs)
        with torch.inference_mode():
            ipex_output = ipex_text_generator(inputs)
        self.assertTrue(isinstance(ipex_text_generator.model, TSModelForCausalLM))
        self.assertTrue(isinstance(ipex_text_generator.model.model, torch.jit.RecursiveScriptModule))
        self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"])
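For reference, a small sketch of how these new tests could be exercised programmatically with `unittest` discovery (the path assumes the repository layout above; running `pytest tests/pipelines` is an equally valid alternative):
```python
import unittest

# Discover and run the pipeline tests added in this commit.
suite = unittest.defaultTestLoader.discover("tests/pipelines", pattern="test_pipelines.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```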
