-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support SentenceTransformers models in optimum (#865)
* Support sentence transformer * FIx error * Separate class for SentenceTransfromers and test * Separate class interface for Sentence Transformers * Check if object has encode method * Formatting and change uotput checking * Tokenizer init moved to __init__, sentence-tranformer pipeline in OVModelForFeatureExtraction was changed * Remove model_max_length default value * Update tests/openvino/test_modeling_sentence_transformers.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Update tests/openvino/test_modeling.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Move tokenizer initialization, other improvements * Update optimum/intel/openvino/modeling_sentence_transformers.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Renaming OVModelForSentenceTransformer to OVSentenceTransformer * Make style * Move checking to init * Update tests/openvino/test_modeling.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Add dummy_openvino_and_sentence_transformers_objects.py * refactoring * Update optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Add tests to check saving and loading model * Add tests to check saving and loading model * Support cls._library_name * Fix test * Refactoring * Refactoring fix * Update optimum/intel/openvino/modeling.py * Fix test --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Co-authored-by: Ella Charlaix <ella@huggingface.co>
- Loading branch information
1 parent
3586b5b
commit 4dc4d57
Showing
13 changed files
with
324 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
optimum/intel/openvino/modeling_sentence_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from pathlib import Path | ||
from types import MethodType | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE | ||
from sentence_transformers import SentenceTransformer | ||
from transformers import AutoTokenizer, PretrainedConfig | ||
from transformers.file_utils import add_start_docstrings | ||
|
||
from .modeling import MODEL_START_DOCSTRING, OVModel | ||
|
||
|
||
@add_start_docstrings( | ||
""" | ||
OpenVINO Model for feature extraction tasks for Sentence Transformers. | ||
""", | ||
MODEL_START_DOCSTRING, | ||
) | ||
class OVSentenceTransformer(OVModel): | ||
export_feature = "feature-extraction" | ||
_library_name = "sentence_transformers" | ||
|
||
def __init__(self, model=None, config=None, tokenizer=None, **kwargs): | ||
super().__init__(model, config, **kwargs) | ||
|
||
self.encode = MethodType(SentenceTransformer.encode, self) | ||
self._text_length = MethodType(SentenceTransformer._text_length, self) | ||
self.default_prompt_name = None | ||
self.truncate_dim = None | ||
self.tokenizer = tokenizer | ||
|
||
def _save_pretrained(self, save_directory: Union[str, Path]): | ||
super()._save_pretrained(save_directory) | ||
self.tokenizer.save_pretrained(save_directory) | ||
|
||
def forward(self, inputs: Dict[str, torch.Tensor]): | ||
self.compile() | ||
input_ids = inputs.get("input_ids") | ||
attention_mask = inputs.get("attention_mask") | ||
token_type_ids = inputs.get("token_type_ids") | ||
|
||
np_inputs = isinstance(input_ids, np.ndarray) | ||
if not np_inputs: | ||
input_ids = np.array(input_ids) | ||
attention_mask = np.array(attention_mask) | ||
token_type_ids = np.array(token_type_ids) if token_type_ids is not None else token_type_ids | ||
|
||
inputs = { | ||
"input_ids": input_ids, | ||
"attention_mask": attention_mask, | ||
} | ||
|
||
# Add the token_type_ids when needed | ||
if "token_type_ids" in self.input_names: | ||
inputs["token_type_ids"] = token_type_ids if token_type_ids is not None else np.zeros_like(input_ids) | ||
|
||
outputs = self._inference(inputs) | ||
return { | ||
"token_embeddings": torch.from_numpy(outputs["token_embeddings"]).to(self.device), | ||
"sentence_embedding": torch.from_numpy(outputs["sentence_embedding"]).to(self.device), | ||
} | ||
|
||
@classmethod | ||
def _from_pretrained( | ||
cls, | ||
model_id: Union[str, Path], | ||
config: PretrainedConfig, | ||
token: Optional[Union[bool, str]] = None, | ||
revision: Optional[str] = None, | ||
force_download: bool = False, | ||
cache_dir: str = HUGGINGFACE_HUB_CACHE, | ||
file_name: Optional[str] = None, | ||
subfolder: str = "", | ||
from_onnx: bool = False, | ||
local_files_only: bool = False, | ||
**kwargs, | ||
): | ||
trust_remote_code = kwargs.pop("trust_remote_code", False) | ||
tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", None) | ||
|
||
tokenizer_args = { | ||
"token": token, | ||
"trust_remote_code": trust_remote_code, | ||
"revision": revision, | ||
"local_files_only": local_files_only, | ||
} | ||
if tokenizer_kwargs: | ||
kwargs["tokenizer_args"].update(tokenizer_kwargs) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_args) | ||
|
||
return super()._from_pretrained( | ||
model_id=model_id, | ||
config=config, | ||
token=token, | ||
revision=revision, | ||
force_download=force_download, | ||
cache_dir=cache_dir, | ||
file_name=file_name, | ||
subfolder=subfolder, | ||
from_onnx=from_onnx, | ||
local_files_only=local_files_only, | ||
tokenizer=tokenizer, | ||
**kwargs, | ||
) | ||
|
||
def tokenize( | ||
self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True | ||
) -> Dict[str, torch.Tensor]: | ||
"""Tokenizes a text and maps tokens to token-ids""" | ||
output = {} | ||
if isinstance(texts[0], str): | ||
to_tokenize = [texts] | ||
elif isinstance(texts[0], dict): | ||
to_tokenize = [] | ||
output["text_keys"] = [] | ||
for lookup in texts: | ||
text_key, text = next(iter(lookup.items())) | ||
to_tokenize.append(text) | ||
output["text_keys"].append(text_key) | ||
to_tokenize = [to_tokenize] | ||
else: | ||
batch1, batch2 = [], [] | ||
for text_tuple in texts: | ||
batch1.append(text_tuple[0]) | ||
batch2.append(text_tuple[1]) | ||
to_tokenize = [batch1, batch2] | ||
|
||
# strip | ||
to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] | ||
|
||
output.update( | ||
self.tokenizer( | ||
*to_tokenize, | ||
padding=padding, | ||
truncation="longest_first", | ||
return_tensors="pt", | ||
) | ||
) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright 2023 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .import_utils import DummyObject, requires_backends | ||
|
||
|
||
class OVSentenceTransformer(metaclass=DummyObject): | ||
_backends = ["openvino", "sentence_transformers"] | ||
|
||
def __init__(self, *args, **kwargs): | ||
requires_backends(self, ["openvino", "sentence_transformers"]) | ||
|
||
@classmethod | ||
def from_pretrained(cls, *args, **kwargs): | ||
requires_backends(cls, ["openvino", "sentence_transformers"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.