
Commit

Merged with main

AlexKoff88 committed Feb 5, 2024
2 parents 55a673b + 0f45751 commit 374b1fc
Showing 19 changed files with 510 additions and 405 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_ipex.yml
@@ -30,6 +30,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cpu
           pip install .[ipex,tests]
       - name: Test with Pytest
         run: |
4 changes: 2 additions & 2 deletions README.md
@@ -78,10 +78,10 @@ It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2
 optimum-cli export openvino --model gpt2 ov_model
 ```

-If you add `--int8`, the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision.
+You can also apply 8-bit weight-only quantization when exporting your model: the model linear and embedding weights will be quantized to INT8, while the activations are kept in floating-point precision.

 ```plain
-optimum-cli export openvino --model gpt2 --int8 ov_model
+optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
 ```

 To apply quantization on both weights and activations, you can find more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov).
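For reference, the same weight-only quantization can be reached from Python. A minimal sketch, assuming optimum-intel's `OVModelForCausalLM` and its `load_in_8bit` flag mirror the `--weight-format int8` CLI option:

```python
# Sketch: 8-bit weight-only quantization via the Python API
# (assumes optimum-intel with the OpenVINO and NNCF extras installed).
from optimum.intel import OVModelForCausalLM

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly;
# load_in_8bit=True quantizes linear/embedding weights to INT8.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, load_in_8bit=True)
model.save_pretrained("ov_model")
```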
Expand Down
5 changes: 1 addition & 4 deletions optimum/exporters/openvino/stateful.py
@@ -22,7 +22,6 @@
 from openvino.runtime import opset13
 from optimum.exporters import TasksManager
 from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version
-from optimum.utils.normalized_config import NormalizedConfigManager


 def model_has_state(ov_model: ov.Model):
@@ -217,9 +216,7 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
     batch_dim = 1 if config.model_type == "chatglm" else 0

     fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
-
-    normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
-    num_attention_heads = normalized_config.num_attention_heads if config.model_type == "bloom" else 1
+    num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1
     make_stateful(
         ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
     )
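The `bloom` special case exists because Bloom fuses batch and heads in its KV cache (keys are laid out as `[batch * num_heads, head_dim, seq_len]`), so reordering the cache across beams must account for the head count. A hypothetical sketch of that index expansion; the helper and layout assumptions are illustrative, not code from this repository:

```python
import torch

def expand_beam_idx(beam_idx: torch.Tensor, num_attention_heads: int) -> torch.Tensor:
    """Map a per-batch beam index to rows of a fused [batch * num_heads] cache."""
    # Row (b, h) lives at b * num_heads + h, so beam b must gather
    # all head rows belonging to beam_idx[b].
    offsets = torch.arange(num_attention_heads)
    return (beam_idx.unsqueeze(1) * num_attention_heads + offsets).reshape(-1)

# Two beams swapping places with 4 heads:
print(expand_beam_idx(torch.tensor([1, 0]), 4))  # tensor([4, 5, 6, 7, 0, 1, 2, 3])
```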
7 changes: 6 additions & 1 deletion optimum/intel/__init__.py
@@ -48,9 +48,11 @@
     "IPEXModelForMaskedLM",
     "IPEXModelForTokenClassification",
     "IPEXModelForQuestionAnswering",
+    "IPEXModelForImageClassification",
+    "IPEXModelForAudioClassification",
+    "IPEXModel",
 ]

-
 try:
     if not (is_openvino_available() and is_nncf_available()):
         raise OptionalDependencyNotAvailable()
@@ -162,7 +164,10 @@
     from .utils.dummy_ipex_objects import *
 else:
     from .ipex import (
+        IPEXModel,
+        IPEXModelForAudioClassification,
         IPEXModelForCausalLM,
+        IPEXModelForImageClassification,
         IPEXModelForMaskedLM,
         IPEXModelForQuestionAnswering,
         IPEXModelForSequenceClassification,
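As a usage sketch, the newly exported classes follow the usual `from_pretrained` pattern; the checkpoint is illustrative, and `export=True` is assumed to trigger TorchScript export as with the other IPEX classes:

```python
# Hypothetical usage of one of the newly exported head classes.
from optimum.intel import IPEXModelForImageClassification

model = IPEXModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224", export=True
)
```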
4 changes: 1 addition & 3 deletions optimum/intel/generation/modeling.py
@@ -66,13 +66,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = False):

 def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
     model_inputs = prepare_jit_inputs(model, task, use_cache)
-    model.config.return_dict = False
+    model.config.return_dict = task not in {"text-generation", "audio-classification"}
     # check if the model_inputs is correct.
     model(**model_inputs)

     torch._C._jit_set_texpr_fuser_enabled(False)
-    if "past_key_values" in model_inputs.keys():
-        model.config.return_dict = False
     if is_torch_version(">=", "2.1.0"):
         traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
     else:
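The effect of the change is that only text-generation and audio-classification models are traced with tuple outputs; every other task keeps `ModelOutput` dicts, which `strict=False` tolerates. A standalone sketch of the same tracing call, with an illustrative checkpoint:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_inputs = dict(tokenizer("a short example", return_tensors="pt"))

# Per the new logic: this task is not text-generation/audio-classification,
# so dict outputs are kept and strict=False lets trace handle them.
model.config.return_dict = True
with torch.no_grad():
    traced = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
```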
3 changes: 3 additions & 0 deletions optimum/intel/ipex/__init__.py
@@ -1,5 +1,8 @@
 from optimum.intel.ipex.modeling_base import (
+    IPEXModel,
+    IPEXModelForAudioClassification,
     IPEXModelForCausalLM,
+    IPEXModelForImageClassification,
     IPEXModelForMaskedLM,
     IPEXModelForQuestionAnswering,
     IPEXModelForSequenceClassification,
20 changes: 1 addition & 19 deletions optimum/intel/ipex/inference.py
@@ -31,25 +31,13 @@
     IPEXModelForMaskedLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
-    IPEXBloomForCausalLM,
-    IPEXMPTForCausalLM,
-    IPEXOPTForCausalLM,
-    IPEXGPTBigCodeForCausalLM,
     IPEXModelForQuestionAnswering,
 )


 from .utils import _HEAD_TO_AUTOMODELS


-_MODEL_TYPE_TO_AUTOMODELS = {
-    "bloom": IPEXBloomForCausalLM,
-    "mpt": IPEXMPTForCausalLM,
-    "opt": IPEXOPTForCausalLM,
-    "big_code": IPEXGPTBigCodeForCausalLM,
-}
-
-
 logger = logging.getLogger(__name__)

 IPEX_NOT_AVAILABLE_ERROR_MSG = (
@@ -146,13 +134,7 @@ def __enter__(self):
                 )
                 if task in _HEAD_TO_AUTOMODELS:
                     model = jit_trace(model, task, use_cache)
-                    model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
-
-                    if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS.keys():
-                        auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[task]
-                    else:
-                        auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
-
+                    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
                     model = auto_model_class(model, self._original.config, use_cache=use_cache)

                 # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
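With the per-architecture classes gone, dispatch reduces to resolving the head class from its task name. A self-contained sketch of the pattern; the stub classes and mapping values stand in for the real ones:

```python
# Stubs stand in for the real IPEX model classes so this runs standalone.
class IPEXModelForCausalLM: ...
class IPEXModelForSequenceClassification: ...

_HEAD_TO_AUTOMODELS = {
    "text-generation": "IPEXModelForCausalLM",
    "text-classification": "IPEXModelForSequenceClassification",
}

task = "text-generation"
auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])  # class name -> class object
assert auto_model_class is IPEXModelForCausalLM
```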