diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index 7f76c28543..c70dc4676b 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -24,6 +24,7 @@
     is_neural_compressor_available,
     is_nncf_available,
     is_openvino_available,
+    is_sentence_transformers_available,
 )
 from .version import __version__
 
@@ -179,6 +180,21 @@
     _import_structure["neural_compressor"].append("INCStableDiffusionPipeline")
 
+try:
+    if not (is_openvino_available() and is_sentence_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    _import_structure["utils.dummy_openvino_and_sentence_transformers_objects"] = [
+        "OVSentenceTransformer",
+    ]
+else:
+    _import_structure["openvino"].extend(
+        [
+            "OVSentenceTransformer",
+        ]
+    )
+
+
 if TYPE_CHECKING:
     try:
         if not is_ipex_available():
@@ -302,6 +318,18 @@
     else:
         from .neural_compressor import INCStableDiffusionPipeline
 
+    try:
+        if not (is_openvino_available() and is_sentence_transformers_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_openvino_and_sentence_transformers_objects import (
+            OVSentenceTransformer,
+        )
+    else:
+        from .openvino import (
+            OVSentenceTransformer,
+        )
+
 else:
     import sys
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index 4ee285f07d..929bdf1be7 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -15,7 +15,12 @@
 import logging
 import warnings
 
-from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available
+from ..utils.import_utils import (
+    is_accelerate_available,
+    is_diffusers_available,
+    is_nncf_available,
+    is_sentence_transformers_available,
+)
 from .utils import (
     OV_DECODER_NAME,
     OV_DECODER_WITH_PAST_NAME,
@@ -77,3 +82,7 @@
         OVStableDiffusionXLImg2ImgPipeline,
         OVStableDiffusionXLPipeline,
     )
+
+
+if is_sentence_transformers_available():
+    from .modeling_sentence_transformers import OVSentenceTransformer
diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index c41874dc80..786704682f 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -369,6 +369,13 @@ class OVModelForFeatureExtraction(OVModel):
     auto_model_class = AutoModel
 
     def __init__(self, model=None, config=None, **kwargs):
+        if {"token_embeddings", "sentence_embedding"}.issubset(
+            {name for output in model.outputs for name in output.names}
+        ):  # Sentence Transformers outputs
+            raise ValueError(
+                "This model is a Sentence Transformers model. Please use `OVSentenceTransformer` to load this model."
+            )
+
         super().__init__(model, config, **kwargs)
 
     @add_start_docstrings_to_model_forward(
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 89da349c82..e8dc113128 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -54,6 +54,7 @@ class OVBaseModel(OptimizedModel):
     auto_model_class = None
     export_feature = None
     _supports_cache_class = False
+    _library_name = "transformers"
 
     def __init__(
         self,
@@ -501,6 +502,7 @@ def _from_transformers(
             force_download=force_download,
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
+            library_name=cls._library_name,
         )
 
         config.save_pretrained(save_dir_path)
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index c880883e82..534edd20b4 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -296,6 +296,7 @@ def _from_transformers(
             ov_config=ov_export_config,
             stateful=stateful,
             model_loading_kwargs=model_loading_kwargs,
+            library_name=cls._library_name,
         )
 
         config.is_decoder = True
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index f70fe7b26e..a39c9a80b8 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -78,6 +78,7 @@ class OVStableDiffusionPipelineBase(OVBaseModel, OVTextualInversionLoaderMixin):
     auto_model_class = StableDiffusionPipeline
     config_name = "model_index.json"
     export_feature = "text-to-image"
+    _library_name = "diffusers"
 
     def __init__(
         self,
@@ -372,6 +373,7 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             ov_config=ov_config,
+            library_name=cls._library_name,
         )
 
         return cls._from_pretrained(
diff --git a/optimum/intel/openvino/modeling_sentence_transformers.py b/optimum/intel/openvino/modeling_sentence_transformers.py
new file mode 100644
index 0000000000..d523993cf2
--- /dev/null
+++ b/optimum/intel/openvino/modeling_sentence_transformers.py
@@ -0,0 +1,142 @@
+from pathlib import Path
+from types import MethodType
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, PretrainedConfig
+from transformers.file_utils import add_start_docstrings
+
+from .modeling import MODEL_START_DOCSTRING, OVModel
+
+
+@add_start_docstrings(
+    """
+    OpenVINO Model for feature extraction tasks for Sentence Transformers.
+ """, + MODEL_START_DOCSTRING, +) +class OVSentenceTransformer(OVModel): + export_feature = "feature-extraction" + _library_name = "sentence_transformers" + + def __init__(self, model=None, config=None, tokenizer=None, **kwargs): + super().__init__(model, config, **kwargs) + + self.encode = MethodType(SentenceTransformer.encode, self) + self._text_length = MethodType(SentenceTransformer._text_length, self) + self.default_prompt_name = None + self.truncate_dim = None + self.tokenizer = tokenizer + + def _save_pretrained(self, save_directory: Union[str, Path]): + super()._save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + + def forward(self, inputs: Dict[str, torch.Tensor]): + self.compile() + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + token_type_ids = inputs.get("token_type_ids") + + np_inputs = isinstance(input_ids, np.ndarray) + if not np_inputs: + input_ids = np.array(input_ids) + attention_mask = np.array(attention_mask) + token_type_ids = np.array(token_type_ids) if token_type_ids is not None else token_type_ids + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Add the token_type_ids when needed + if "token_type_ids" in self.input_names: + inputs["token_type_ids"] = token_type_ids if token_type_ids is not None else np.zeros_like(input_ids) + + outputs = self._inference(inputs) + return { + "token_embeddings": torch.from_numpy(outputs["token_embeddings"]).to(self.device), + "sentence_embedding": torch.from_numpy(outputs["sentence_embedding"]).to(self.device), + } + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: PretrainedConfig, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + file_name: Optional[str] = None, + subfolder: str = "", + from_onnx: bool = False, + local_files_only: bool = False, + **kwargs, + ): + trust_remote_code = kwargs.pop("trust_remote_code", False) + tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", None) + + tokenizer_args = { + "token": token, + "trust_remote_code": trust_remote_code, + "revision": revision, + "local_files_only": local_files_only, + } + if tokenizer_kwargs: + kwargs["tokenizer_args"].update(tokenizer_kwargs) + + tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_args) + + return super()._from_pretrained( + model_id=model_id, + config=config, + token=token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + from_onnx=from_onnx, + local_files_only=local_files_only, + tokenizer=tokenizer, + **kwargs, + ) + + def tokenize( + self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True + ) -> Dict[str, torch.Tensor]: + """Tokenizes a text and maps tokens to token-ids""" + output = {} + if isinstance(texts[0], str): + to_tokenize = [texts] + elif isinstance(texts[0], dict): + to_tokenize = [] + output["text_keys"] = [] + for lookup in texts: + text_key, text = next(iter(lookup.items())) + to_tokenize.append(text) + output["text_keys"].append(text_key) + to_tokenize = [to_tokenize] + else: + batch1, batch2 = [], [] + for text_tuple in texts: + batch1.append(text_tuple[0]) + batch2.append(text_tuple[1]) + to_tokenize = [batch1, batch2] + + # strip + to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] + + output.update( + self.tokenizer( + *to_tokenize, + 
+                padding=padding,
+                truncation="longest_first",
+                return_tensors="pt",
+            )
+        )
+        return output
diff --git a/optimum/intel/utils/__init__.py b/optimum/intel/utils/__init__.py
index 50cdfa143e..b79deeb62d 100644
--- a/optimum/intel/utils/__init__.py
+++ b/optimum/intel/utils/__init__.py
@@ -24,6 +24,7 @@
     is_nncf_available,
     is_numa_available,
     is_openvino_available,
+    is_sentence_transformers_available,
     is_torch_version,
     is_transformers_available,
     is_transformers_version,
diff --git a/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py b/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py
new file mode 100644
index 0000000000..fd13e5f56a
--- /dev/null
+++ b/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .import_utils import DummyObject, requires_backends
+
+
+class OVSentenceTransformer(metaclass=DummyObject):
+    _backends = ["openvino", "sentence_transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["openvino", "sentence_transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["openvino", "sentence_transformers"])
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 032280e940..8024d23899 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -159,6 +159,16 @@
     _numa_available = False
 
+_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None
+_sentence_transformers_version = "N/A"
+
+if _sentence_transformers_available:
+    try:
+        _sentence_transformers_version = importlib_metadata.version("sentence_transformers")
+    except importlib_metadata.PackageNotFoundError:
+        _sentence_transformers_available = False
+
+
 def is_transformers_available():
     return _transformers_available
 
@@ -280,6 +290,10 @@
 def is_accelerate_available():
     return _accelerate_available
 
+
+def is_sentence_transformers_available():
+    return _sentence_transformers_available
+
 
 def is_numa_available():
     return _numa_available
diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py
index 9da496ae05..b95c838815 100644
--- a/tests/openvino/test_exporters_cli.py
+++ b/tests/openvino/test_exporters_cli.py
@@ -36,6 +36,7 @@
     OVModelForSeq2SeqLM,
     OVModelForSequenceClassification,
     OVModelForTokenClassification,
+    OVSentenceTransformer,
     OVStableDiffusionPipeline,
     OVStableDiffusionXLPipeline,
 )
@@ -316,5 +317,5 @@ def test_exporters_cli_sentence_transformers(self):
             shell=True,
             check=True,
         )
-        model = eval(_HEAD_TO_AUTOMODELS["feature-extraction"]).from_pretrained(tmpdir, compile=False)
+        model = OVSentenceTransformer.from_pretrained(tmpdir, compile=False)
         self.assertFalse("last_hidden_state" in model.output_names)
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 97e7290ef0..30c70c7c94 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -77,6 +77,7 @@
     OVModelForSpeechSeq2Seq,
     OVModelForTokenClassification,
     OVModelForVision2Seq,
+    OVSentenceTransformer,
     OVStableDiffusionPipeline,
 )
 from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME
@@ -655,6 +656,20 @@ def test_pipeline(self, model_arch):
         del model
         gc.collect()
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_sentence_transformers_pipeline(self, model_arch):
+        """
+        Check that loading a saved IR model with Sentence Transformers outputs
+        via OVModelForFeatureExtraction raises an appropriate exception.
+        """
+        model_id = MODEL_NAMES[model_arch]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            save_dir = str(tmp_dir)
+            OVSentenceTransformer.from_pretrained(model_id, export=True).save_pretrained(save_dir)
+            with self.assertRaises(Exception) as context:
+                OVModelForFeatureExtraction.from_pretrained(save_dir)
+            self.assertIn("Please use `OVSentenceTransformer`", str(context.exception))
+
 
 class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
diff --git a/tests/openvino/test_modeling_sentence_transformers.py b/tests/openvino/test_modeling_sentence_transformers.py
new file mode 100644
index 0000000000..acda045123
--- /dev/null
+++ b/tests/openvino/test_modeling_sentence_transformers.py
@@ -0,0 +1,74 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+from sentence_transformers import SentenceTransformer
+from transformers import (
+    PretrainedConfig,
+    set_seed,
+)
+
+from optimum.intel import OVSentenceTransformer
+
+
+SEED = 42
+
+F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
+
+MODEL_NAMES = {
+    "bert": "sentence-transformers/all-MiniLM-L6-v2",
+    "mpnet": "sentence-transformers/all-mpnet-base-v2",
+}
+
+
+class OVModelForSTFeatureExtractionIntegrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES = (
+        "bert",
+        "mpnet",
+    )
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        ov_model = OVSentenceTransformer.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+        self.assertTrue(hasattr(ov_model, "encode"))
+        st_model = SentenceTransformer(model_id)
+        sentences = ["This is an example sentence", "Each sentence is converted"]
+        st_embeddings = st_model.encode(sentences)
+        ov_embeddings = ov_model.encode(sentences)
+        # Compare tensor outputs
+        self.assertTrue(np.allclose(ov_embeddings, st_embeddings, atol=1e-4))
+        del st_embeddings
+        del ov_model
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_sentence_transformers_save_and_infer(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        ov_model = OVSentenceTransformer.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model_save_path = os.path.join(tmpdirname, "sentence_transformers_ov_model")
+            ov_model.save_pretrained(model_save_path)
+            model = OVSentenceTransformer.from_pretrained(model_save_path)
+            sentences = ["This is an example sentence", "Each sentence is converted"]
+            model.encode(sentences)
+        gc.collect()
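For reviewers, a minimal usage sketch of the new class, based on the tests in this diff (the model id and sentences mirror the test constants; the save path and printed shape are illustrative, not part of the change):

```python
# Export a Sentence Transformers checkpoint to OpenVINO IR and run inference.
# OVSentenceTransformer binds SentenceTransformer.encode onto the OV model,
# so the call pattern matches sentence_transformers once the model is loaded.
from optimum.intel import OVSentenceTransformer

model_id = "sentence-transformers/all-MiniLM-L6-v2"

# export=True converts the original checkpoint to OpenVINO IR on the fly
model = OVSentenceTransformer.from_pretrained(model_id, export=True)

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
print(embeddings.shape)  # one embedding vector per input sentence

# The IR model and tokenizer are saved together and can be reloaded later
model.save_pretrained("all-MiniLM-L6-v2-ov")
model = OVSentenceTransformer.from_pretrained("all-MiniLM-L6-v2-ov")

# Loading the same directory with OVModelForFeatureExtraction now fails fast:
# the guard added to its __init__ detects the token_embeddings /
# sentence_embedding outputs and raises a ValueError pointing the user
# to OVSentenceTransformer.
```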