Skip to content

Commit

Permalink
Support SentenceTransformers models in optimum (#865)
Browse files Browse the repository at this point in the history
* Support sentence transformer

* Fix error

* Separate class for SentenceTransformers and test

* Separate class interface for Sentence Transformers

* Check if object has encode method

* Formatting and change output checking

* Tokenizer init moved to __init__, sentence-transformer pipeline in OVModelForFeatureExtraction was changed

* Remove model_max_length default value

* Update tests/openvino/test_modeling_sentence_transformers.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Update tests/openvino/test_modeling.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Move tokenizer initialization, other improvements

* Update optimum/intel/openvino/modeling_sentence_transformers.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Renaming OVModelForSentenceTransformer to OVSentenceTransformer

* Make style

* Move checking to init

* Update tests/openvino/test_modeling.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Add dummy_openvino_and_sentence_transformers_objects.py

* refactoring

* Update optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

* Add tests to check saving and loading model

* Add tests to check saving and loading model

* Support cls._library_name

* Fix test

* Refactoring

* Refactoring fix

* Update optimum/intel/openvino/modeling.py

* Fix test

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Co-authored-by: Ella Charlaix <ella@huggingface.co>
  • Loading branch information
4 people authored Sep 6, 2024
1 parent 3586b5b commit 4dc4d57
Show file tree
Hide file tree
Showing 13 changed files with 324 additions and 2 deletions.
28 changes: 28 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_neural_compressor_available,
is_nncf_available,
is_openvino_available,
is_sentence_transformers_available,
)
from .version import __version__

Expand Down Expand Up @@ -179,6 +180,21 @@
_import_structure["neural_compressor"].append("INCStableDiffusionPipeline")


# Register `OVSentenceTransformer` in the lazy import structure: under the
# "openvino" entry when both the OpenVINO and sentence-transformers
# availability checks pass, otherwise under the dummy-objects module.
try:
if not (is_openvino_available() and is_sentence_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_openvino_and_sentence_transformers_objects"] = [
"OVSentenceTransformer",
]
else:
_import_structure["openvino"].extend(
[
"OVSentenceTransformer",
]
)


if TYPE_CHECKING:
try:
if not is_ipex_available():
Expand Down Expand Up @@ -302,6 +318,18 @@
else:
from .neural_compressor import INCStableDiffusionPipeline

# Eager import of `OVSentenceTransformer`: the real class from `.openvino`
# when both availability checks pass, the dummy placeholder otherwise.
try:
if not (is_openvino_available() and is_sentence_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_openvino_and_sentence_transformers_objects import (
OVSentenceTransformer,
)
else:
from .openvino import (
OVSentenceTransformer,
)

else:
import sys

Expand Down
11 changes: 10 additions & 1 deletion optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import logging
import warnings

from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available
from ..utils.import_utils import (
is_accelerate_available,
is_diffusers_available,
is_nncf_available,
is_sentence_transformers_available,
)
from .utils import (
OV_DECODER_NAME,
OV_DECODER_WITH_PAST_NAME,
Expand Down Expand Up @@ -77,3 +82,7 @@
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
)


# Import the wrapper only when the sentence-transformers availability check
# passes; `modeling_sentence_transformers` imports `sentence_transformers`
# at module level.
if is_sentence_transformers_available():
from .modeling_sentence_transformers import OVSentenceTransformer
7 changes: 7 additions & 0 deletions optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ class OVModelForFeatureExtraction(OVModel):
auto_model_class = AutoModel

def __init__(self, model=None, config=None, **kwargs):
    """Reject Sentence Transformers exports, then defer to the parent initializer.

    A model whose output names include both ``token_embeddings`` and
    ``sentence_embedding`` is treated as a Sentence Transformers export.

    Raises:
        ValueError: if both of those output names are present on ``model``.
    """
    sentence_transformers_outputs = {"token_embeddings", "sentence_embedding"}

    # Collect every name carried by every output port of the model.
    exported_names = set()
    for output in model.outputs:
        exported_names.update(output.names)

    if sentence_transformers_outputs <= exported_names:  # Sentence Transformers outputs
        raise ValueError(
            "This model is a Sentence Transformers model. Please use `OVSentenceTransformer` to load this model."
        )

    super().__init__(model, config, **kwargs)

@add_start_docstrings_to_model_forward(
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OVBaseModel(OptimizedModel):
auto_model_class = None
export_feature = None
_supports_cache_class = False
_library_name = "transformers"

def __init__(
self,
Expand Down Expand Up @@ -501,6 +502,7 @@ def _from_transformers(
force_download=force_download,
trust_remote_code=trust_remote_code,
ov_config=ov_config,
library_name=cls._library_name,
)

config.save_pretrained(save_dir_path)
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _from_transformers(
ov_config=ov_export_config,
stateful=stateful,
model_loading_kwargs=model_loading_kwargs,
library_name=cls._library_name,
)

config.is_decoder = True
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class OVStableDiffusionPipelineBase(OVBaseModel, OVTextualInversionLoaderMixin):
auto_model_class = StableDiffusionPipeline
config_name = "model_index.json"
export_feature = "text-to-image"
_library_name = "diffusers"

def __init__(
self,
Expand Down Expand Up @@ -372,6 +373,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
ov_config=ov_config,
library_name=cls._library_name,
)

return cls._from_pretrained(
Expand Down
142 changes: 142 additions & 0 deletions optimum/intel/openvino/modeling_sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from pathlib import Path
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, PretrainedConfig
from transformers.file_utils import add_start_docstrings

from .modeling import MODEL_START_DOCSTRING, OVModel


@add_start_docstrings(
    """
    OpenVINO Model for feature extraction tasks for Sentence Transformers.
    """,
    MODEL_START_DOCSTRING,
)
class OVSentenceTransformer(OVModel):
    export_feature = "feature-extraction"
    _library_name = "sentence_transformers"

    def __init__(self, model=None, config=None, tokenizer=None, **kwargs):
        """Initialize the OpenVINO model and attach the tokenizer.

        Args:
            model: forwarded to the parent initializer.
            config: forwarded to the parent initializer.
            tokenizer: tokenizer used by `tokenize` and saved by `_save_pretrained`.
            **kwargs: forwarded to the parent initializer.
        """
        super().__init__(model, config, **kwargs)

        # Bind the upstream `SentenceTransformer` implementations to this
        # instance so they run with this object as `self`.
        self.encode = MethodType(SentenceTransformer.encode, self)
        self._text_length = MethodType(SentenceTransformer._text_length, self)
        # NOTE(review): presumably these two attributes are read by the borrowed
        # `encode` -- confirm against the installed sentence-transformers version.
        self.default_prompt_name = None
        self.truncate_dim = None
        self.tokenizer = tokenizer

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """Save the model through the parent class, then the tokenizer, into `save_directory`."""
        super()._save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Run inference on tokenized inputs.

        Args:
            inputs: mapping that provides `input_ids` and `attention_mask`, and
                optionally `token_type_ids`.

        Returns:
            A dict with `token_embeddings` and `sentence_embedding` as torch
            tensors moved to `self.device`.
        """
        self.compile()
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        token_type_ids = inputs.get("token_type_ids")

        # Convert to NumPy unless the caller already supplied NumPy arrays.
        np_inputs = isinstance(input_ids, np.ndarray)
        if not np_inputs:
            input_ids = np.array(input_ids)
            attention_mask = np.array(attention_mask)
            token_type_ids = np.array(token_type_ids) if token_type_ids is not None else token_type_ids

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        # Add the token_type_ids when needed
        if "token_type_ids" in self.input_names:
            inputs["token_type_ids"] = token_type_ids if token_type_ids is not None else np.zeros_like(input_ids)

        outputs = self._inference(inputs)
        return {
            "token_embeddings": torch.from_numpy(outputs["token_embeddings"]).to(self.device),
            "sentence_embedding": torch.from_numpy(outputs["sentence_embedding"]).to(self.device),
        }

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: PretrainedConfig,
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        file_name: Optional[str] = None,
        subfolder: str = "",
        from_onnx: bool = False,
        local_files_only: bool = False,
        **kwargs,
    ):
        """Load the tokenizer for `model_id`, then delegate model loading to the parent class.

        Two optional entries are consumed from `kwargs`:
            trust_remote_code: passed to the tokenizer loader (defaults to False).
            tokenizer_kwargs: extra arguments for `AutoTokenizer.from_pretrained`;
                they override the defaults derived from `token`,
                `trust_remote_code`, `revision` and `local_files_only`.

        The remaining arguments are forwarded unchanged to the parent
        `_from_pretrained`, together with the loaded tokenizer.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", None)

        tokenizer_args = {
            "token": token,
            "trust_remote_code": trust_remote_code,
            "revision": revision,
            "local_files_only": local_files_only,
        }
        if tokenizer_kwargs:
            # Merge into the local dict: `kwargs` never holds a "tokenizer_args"
            # entry, so indexing it raised KeyError for every caller that
            # supplied `tokenizer_kwargs`.
            tokenizer_args.update(tokenizer_kwargs)

        tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_args)

        return super()._from_pretrained(
            model_id=model_id,
            config=config,
            token=token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            file_name=file_name,
            subfolder=subfolder,
            from_onnx=from_onnx,
            local_files_only=local_files_only,
            tokenizer=tokenizer,
            **kwargs,
        )

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids

        The type of `texts[0]` selects the layout: plain strings form a single
        column, dicts contribute their first (key, value) pair (keys are
        returned under `text_keys`), and anything else is treated as a pair
        and split into two columns.
        """
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        # strip
        to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
            )
        )
        return output
1 change: 1 addition & 0 deletions optimum/intel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_nncf_available,
is_numa_available,
is_openvino_available,
is_sentence_transformers_available,
is_torch_version,
is_transformers_available,
is_transformers_version,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .import_utils import DummyObject, requires_backends


class OVSentenceTransformer(metaclass=DummyObject):
    """Placeholder for `OVSentenceTransformer` used when its backends are not both available.

    Each entry point defers to `requires_backends` with the backends listed in `_backends`.
    """

    _backends = ["openvino", "sentence_transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)
14 changes: 14 additions & 0 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@
_numa_available = False


# `_sentence_transformers_available` is the boolean result of the import-spec
# lookup; the installed version string lives in its own variable so it can
# never overwrite (and mask) that flag.
_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None
_sentence_transformers_version = "N/A"

if _sentence_transformers_available:
    try:
        _sentence_transformers_version = importlib_metadata.version("sentence_transformers")
    except importlib_metadata.PackageNotFoundError:
        # An importable module without distribution metadata is treated as unavailable.
        _sentence_transformers_available = False


def is_transformers_available():
return _transformers_available

Expand Down Expand Up @@ -280,6 +290,10 @@ def is_accelerate_available():
return _accelerate_available


def is_sentence_transformers_available():
    """Return True when the module-level sentence-transformers flag is truthy.

    The flag is coerced to ``bool`` so callers always receive a real boolean,
    even if the module-level variable holds a non-boolean value such as a
    version string.
    """
    return bool(_sentence_transformers_available)


def is_numa_available():
return _numa_available

Expand Down
3 changes: 2 additions & 1 deletion tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
OVSentenceTransformer,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
)
Expand Down Expand Up @@ -316,5 +317,5 @@ def test_exporters_cli_sentence_transformers(self):
shell=True,
check=True,
)
model = eval(_HEAD_TO_AUTOMODELS["feature-extraction"]).from_pretrained(tmpdir, compile=False)
model = OVSentenceTransformer.from_pretrained(tmpdir, compile=False)
self.assertFalse("last_hidden_state" in model.output_names)
15 changes: 15 additions & 0 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
OVModelForVision2Seq,
OVSentenceTransformer,
OVStableDiffusionPipeline,
)
from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME
Expand Down Expand Up @@ -655,6 +656,20 @@ def test_pipeline(self, model_arch):
del model
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_sentence_transformers_pipeline(self, model_arch):
    """
    Loading an IR model saved by `OVSentenceTransformer` through
    `OVModelForFeatureExtraction` must raise an exception whose message
    points the user to `OVSentenceTransformer`.
    """
    checkpoint = MODEL_NAMES[model_arch]
    with tempfile.TemporaryDirectory() as tmp_dir:
        export_dir = str(tmp_dir)
        # Export with the Sentence Transformers wrapper and write the IR to disk.
        OVSentenceTransformer.from_pretrained(checkpoint, export=True).save_pretrained(export_dir)
        with self.assertRaises(Exception) as raised:
            OVModelForFeatureExtraction.from_pretrained(export_dir)
        self.assertIn("Please use `OVSentenceTransformer`", str(raised.exception))


class OVModelForCausalLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
Expand Down
Loading

0 comments on commit 4dc4d57

Please sign in to comment.