add IPEX model and README #512

Closed · wants to merge 8 commits
27 changes: 26 additions & 1 deletion README.md
@@ -6,6 +6,8 @@

🤗 Optimum Intel is the interface between the 🤗 Transformers and Diffusers libraries and the different tools and libraries provided by Intel to accelerate end-to-end pipelines on Intel architectures.

[Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) (IPEX) is an open-source library that provides optimizations for both eager mode and graph mode. Compared to eager mode, graph mode in PyTorch normally yields better performance thanks to optimization techniques such as operation fusion.
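
As a rough sketch of the eager versus graph mode distinction (not part of this diff; the checkpoint, example input and tracing arguments are illustrative assumptions), a model can first be optimized with `ipex.optimize` and then traced to TorchScript so that graph-level optimizations such as operation fusion can apply:

```python
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Eager-mode optimizations (weight prepacking, optimized kernels)
model = ipex.optimize(model)

# Graph mode: trace the model to TorchScript so operation fusion can be applied
inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    traced = torch.jit.trace(model, example_kwarg_inputs=dict(inputs), strict=False)
    traced = torch.jit.freeze(traced)
    outputs = traced(**inputs)
```

The `IPEXModelForXxx` classes added in this PR wrap a similar `ipex.optimize` plus TorchScript tracing flow behind `from_pretrained(..., export=True)`.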

Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) is an open-source library enabling the use of the most popular compression techniques such as quantization, pruning and knowledge distillation. It supports automatic accuracy-driven tuning strategies so that users can easily generate quantized models. Users can apply static, dynamic and quantization-aware training approaches while specifying an expected accuracy criterion. It also supports different weight pruning techniques, enabling the creation of pruned models that meet a predefined sparsity target.
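
For instance, post-training dynamic quantization with Optimum Intel might look roughly like the following sketch (the checkpoint and save directory are illustrative, not part of this diff):

```python
from neural_compressor.config import PostTrainingQuantConfig
from optimum.intel import INCQuantizer
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
)

# Dynamic quantization requires no calibration dataset
quantization_config = PostTrainingQuantConfig(approach="dynamic")
quantizer = INCQuantizer.from_pretrained(model)
quantizer.quantize(quantization_config=quantization_config, save_directory="quantized_model")
```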

[OpenVINO](https://docs.openvino.ai/latest/index.html) is an open-source toolkit that enables high-performance inference for Intel CPUs, GPUs and special DL inference accelerators ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
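
For example, exporting a Transformers model to the OpenVINO IR on the fly and running it with OpenVINO Runtime can look roughly like this sketch (the checkpoint is illustrative):

```python
from optimum.intel import OVModelForSequenceClassification
from transformers import AutoTokenizer, pipeline

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
# export=True converts the PyTorch checkpoint to the OpenVINO IR format
model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(classifier("He's a dreadful magician."))
```
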
@@ -19,6 +21,7 @@ To install the latest release of 🤗 Optimum Intel with the corresponding requi
|:-----------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------|
| [Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) | `pip install --upgrade-strategy eager "optimum[neural-compressor]"` |
| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `pip install --upgrade-strategy eager "optimum[openvino,nncf]"` |
| [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) | `pip install --upgrade-strategy eager "optimum[ipex]"` |

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

@@ -37,7 +40,7 @@ or to install from source including dependencies:
python -m pip install "optimum-intel[extras]"@git+https://github.com/huggingface/optimum-intel.git
```

where `extras` can be one or more of `neural-compressor`, `openvino`, `nncf`.
where `extras` can be one or more of `ipex`, `neural-compressor`, `openvino`, `nncf`.

# Quick tour

@@ -199,6 +202,28 @@ Quantization aware training (QAT) is applied in order to simulate the effects of
You can find more examples in the [documentation](https://huggingface.co/docs/optimum/intel/index).


## Intel Extension for PyTorch

To load a model and run generation with IPEX graph mode, you can just replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class.

```diff
import torch
from transformers import AutoTokenizer, pipeline
- from transformers import AutoModelForCausalLM
+ from optimum.intel.ipex import IPEXModelForCausalLM


model_id = "gpt2"
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = text_generator("This is an example input")
```

For now, we only support text-generation tasks.
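
Based on the `_save_pretrained` and `_from_pretrained` methods added in this PR, the traced model can presumably also be saved and reloaded later without re-exporting (a sketch continuing the example above; the directory name is illustrative):

```python
# Save the TorchScript model and its config, then reload it later without export=True
model.save_pretrained("ipex_gpt2")
model = IPEXModelForCausalLM.from_pretrained("ipex_gpt2")
```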


## Running the examples

Check out the [`examples`](https://github.com/huggingface/optimum-intel/tree/main/examples) directory to see how 🤗 Optimum Intel can be used to optimize models and accelerate inference.
4 changes: 2 additions & 2 deletions docs/source/reference_inc.mdx
@@ -43,8 +43,8 @@ specific language governing permissions and limitations under the License.

## INCModelForCausalLM

[[autodoc]] neural_compressor.modeling_decoder.INCModelForCausalLM
[[autodoc]] neural_compressor.modeling_base.INCModelForCausalLM

## INCModelForSeq2SeqLM

[[autodoc]] neural_compressor.modeling_base.INCModelForSeq2SeqLM
[[autodoc]] neural_compressor.modeling_base.INCModelForSeq2SeqLM
25 changes: 21 additions & 4 deletions optimum/intel/__init__.py
@@ -35,9 +35,20 @@
    if not is_ipex_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    _import_structure["utils.dummy_ipex_objects"] = ["inference_mode"]
    from .utils import dummy_ipex_objects

    _import_structure["utils.dummy_ipex_objects"] = [
        name for name in dir(dummy_ipex_objects) if not name.startswith("_")
    ]
else:
    _import_structure["ipex"] = ["inference_mode"]
    _import_structure["ipex"] = [
        "inference_mode",
        "IPEXModelForCausalLM",
        "IPEXModelForSequenceClassification",
        "IPEXModelForMaskedLM",
        "IPEXModelForTokenClassification",
    ]


try:
    if not (is_openvino_available() and is_nncf_available()):
@@ -144,9 +155,15 @@
        if not is_ipex_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_ipex_objects import inference_mode
        from .utils.dummy_ipex_objects import *
    else:
        from .ipex import inference_mode
        from .ipex import (
            IPEXModelForCausalLM,
            IPEXModelForMaskedLM,
            IPEXModelForSequenceClassification,
            IPEXModelForTokenClassification,
            inference_mode,
        )

    try:
        if not (is_openvino_available() and is_nncf_available()):
2 changes: 1 addition & 1 deletion optimum/intel/generation/modeling.py
@@ -102,6 +102,7 @@ def __init__(
        self.model_save_dir = model_save_dir
        self.preprocessors = kwargs.get("preprocessors", [])
        self.use_cache = use_cache
        # TODO: add XPU support
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
        self.model_dtype = kwargs.get("model_dtype", None)
@@ -282,7 +283,6 @@ def forward(
inputs["position_ids"] = position_ids

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
if past_key_values is None:
nb_pkv = 2
7 changes: 7 additions & 0 deletions optimum/intel/ipex/__init__.py
@@ -1 +1,8 @@
from optimum.intel.ipex.modeling_base import (
    IPEXModelForCausalLM,
    IPEXModelForMaskedLM,
    IPEXModelForSequenceClassification,
    IPEXModelForTokenClassification,
)

from .inference import inference_mode
263 changes: 263 additions & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -0,0 +1,263 @@
import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union

import intel_extension_for_pytorch as ipex
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    GenerationMixin,
    PretrainedConfig,
)
from transformers.models.auto.auto_factory import _get_model_class
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.modeling_base import OptimizedModel

from ..generation.modeling import BaseModelForCausalLM, jit_trace
from ..utils.import_utils import is_torch_version
from ..utils.modeling_utils import patch_decoder_attention_mask


# from .utils import generation_tasks


SUPPORT_MODEL_LIST_FOR_CAUSAL_LM = {
    # "llama": LlamaForCausalLM
}

SUPPORT_TASK_LIST = {"text-generation": SUPPORT_MODEL_LIST_FOR_CAUSAL_LM}
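# The per-task model lists are empty placeholders for now, so every checkpoint falls back to the
# generic jit_trace path in IPEXModel.apply_jit_optimize.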


logger = logging.getLogger(__name__)


class IPEXModel(OptimizedModel):
    auto_model_class = AutoModel
    export_feature = "feature-extraction"
    base_model_prefix = "ipex_model"

    def __init__(
        self,
        model,
        config: PretrainedConfig = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        use_cache: bool = True,
        **kwargs,
    ):
        OptimizedModel.__init__(self, model=model, config=config)
        # To do: add XPU support
        self._device = torch.device("cpu")
        self.model.to(self._device)

        # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
        # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
        AutoConfig.register(self.base_model_prefix, AutoConfig)
        if hasattr(self.auto_model_class, "register"):
            self.auto_model_class.register(AutoConfig, self.__class__)

    @classmethod
    def _from_transformers(
        cls,
        model_id: str,
        config: PretrainedConfig,
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        use_cache: bool = True,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        **kwargs,
    ):
        if is_torch_version("<", "2.1.0"):
            raise ImportError("`torch>=2.1.0` is needed to trace your model")

        task = cls.export_feature
        model_kwargs = {
            "revision": revision,
            "use_auth_token": use_auth_token,
            "cache_dir": cache_dir,
            "subfolder": subfolder,
            "local_files_only": local_files_only,
            "force_download": force_download,
            "torch_dtype": torch_dtype,
            "device": "cpu",
        }
        model_type = None
        support_ipex_transformers = False
        if task in SUPPORT_TASK_LIST.keys():
            for name in SUPPORT_TASK_LIST[task].keys():
                if name in model_id:
                    support_ipex_transformers = True
                    model_type = name
                    break

        if support_ipex_transformers and task in SUPPORT_TASK_LIST and model_type in SUPPORT_TASK_LIST[task]:
            # model = SUPPORT_TASK_LIST[task][model_type].from_pretrained(model_id, **model_kwargs)
            pass
        else:
            model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
            model = patch_decoder_attention_mask(model)
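        # Next, ipex.optimize applies eager-mode optimizations such as weight prepacking and
        # optimized kernel selection before the model is (optionally) traced to TorchScript.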

        model = ipex.optimize(model, dtype=torch_dtype, level="O1", auto_kernel_selection=True)

        if kwargs.pop("jit", True):
            try:
                traced_model = cls.apply_jit_optimize(model, task, use_cache, support_ipex_transformers)
                save_dir = TemporaryDirectory()
                save_dir_path = Path(save_dir.name)
                torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
                config.torchscript = True

                return cls._from_pretrained(
                    model_id=save_dir_path,
                    config=config,
                    use_cache=use_cache,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    force_download=force_download,
                    cache_dir=cache_dir,
                    local_files_only=local_files_only,
                    model_dtype=torch_dtype,
                    **kwargs,
                )
            except Exception as e:
                logger.warning(f"Failed to use PyTorch JIT mode due to: {e}.")

        return cls(model, config=config, use_cache=use_cache, model_dtype=torch_dtype, **kwargs)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: PretrainedConfig,
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = WEIGHTS_NAME,
        local_files_only: bool = False,
        subfolder: str = "",
        use_cache: bool = True,
        **kwargs,
    ):
        # Load the model from local directory
        if os.path.isdir(model_id):
            model_cache_path = os.path.join(model_id, file_name)
            model_save_dir = model_id
        # Download the model from the hub
        else:
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=file_name,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                subfolder=subfolder,
            )
            model_save_dir = Path(model_cache_path).parent
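        # If the config carries a `torchscript` flag, the weights file is loaded below as a TorchScript
        # archive and frozen for inference; otherwise the eager transformers model class is instantiated.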

        if getattr(config, "torchscript", False):
            model = torch.jit.load(model_cache_path)
            torch.jit.freeze(model.eval())
        else:
            model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
            model = model_class.from_pretrained(model_save_dir)

        return cls(model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs)

    def _save_pretrained(self, save_directory: Union[str, Path]):
        output_path = os.path.join(save_directory, WEIGHTS_NAME)

        if isinstance(self.model, torch.nn.Module):
            state_dict = self.model.state_dict()
            torch.save(state_dict, output_path)
        else:
            torch.jit.save(self.model, output_path)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def eval(self):
        self.model.eval()
        return self

    @property
    def device(self) -> torch.device:
        return self._device

    def to(self, device: Union[torch.device, str]):
        self._device = device if isinstance(device, torch.device) else torch.device(device)
        self.model.to(self._device)
        return self

    def can_generate(self):
        return isinstance(self.model, GenerationMixin)

    def generate(self, *args, **kwargs):
        if not self.can_generate():
            raise TypeError(
                f"The current model class {self.model.__class__} is not compatible with `.generate()`, as it doesn't have a language model head."
            )
        return self.model.generate(*args, **kwargs)

    @classmethod
    def apply_jit_optimize(cls, model, task, use_cache, support_ipex_transformers=False):
        return jit_trace(model, task, use_cache)


class IPEXModelForSequenceClassification(IPEXModel):
    auto_model_class = AutoModelForSequenceClassification
    export_feature = "text-classification"


class IPEXModelForMaskedLM(IPEXModel):
    auto_model_class = AutoModelForMaskedLM
    export_feature = "fill-mask"


class IPEXModelForTokenClassification(IPEXModel):
    auto_model_class = AutoModelForTokenClassification
    export_feature = "token-classification"


class IPEXModelForCausalLM(IPEXModel, BaseModelForCausalLM):
    auto_model_class = AutoModelForCausalLM
    export_feature = "text-generation"
    forward = BaseModelForCausalLM.forward
    generate = BaseModelForCausalLM.generate
    can_generate = BaseModelForCausalLM.can_generate

    def __init__(
        self,
        model,
        config: PretrainedConfig = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        use_cache: bool = True,
        **kwargs,
    ):
        IPEXModel.__init__(self, model, config)
        BaseModelForCausalLM.__init__(self, model, config, model_save_dir, use_cache, **kwargs)

    @classmethod
    def apply_jit_optimize(cls, model, task, use_cache, support_ipex_transformers):
        if not support_ipex_transformers:
            return jit_trace(model, task, use_cache)
        else:
            # from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
            # dummy_jit_inputs = get_dummy_input(task, model)  # From ipex
            # model = torch.jit.trace(model, example_input_kwargs=dummy_jit_inputs)
            return model
1 change: 1 addition & 0 deletions optimum/intel/ipex/utils.py
@@ -0,0 +1 @@
generation_tasks = ("text-generation",)