Skip to content

Commit

Permalink
Fix sentence-transformers model export
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 13, 2024
1 parent a9a235b commit b2751dc
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
5 changes: 4 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def main_export(
compression_ratio: Optional[float] = None,
stateful: bool = True,
convert_tokenizer: bool = False,
library_name: Optional[str] = None,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -139,7 +140,9 @@ def main_export(
original_task = task
task = TasksManager.map_from_synonym(task)
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder=subfolder)
library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, library_name=library_name
)

if task == "auto":
try:
Expand Down
1 change: 0 additions & 1 deletion optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ def export_pytorch(
try:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())
Expand Down
44 changes: 44 additions & 0 deletions optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -50,6 +51,7 @@

from optimum.exporters import TasksManager

from ...exporters.openvino import main_export
from ..utils.import_utils import is_timm_available, is_timm_version
from .modeling_base import OVBaseModel
from .utils import _is_timm_ov_dir
Expand Down Expand Up @@ -411,6 +413,48 @@ def forward(
)
return BaseModelOutput(last_hidden_state=last_hidden_state)

@classmethod
def _from_transformers(
    cls,
    model_id: str,
    config: PretrainedConfig,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
    task: Optional[str] = None,
    trust_remote_code: bool = False,
    load_in_8bit: Optional[bool] = None,
    load_in_4bit: Optional[bool] = None,
    **kwargs,
):
    """Export a Transformers model to OpenVINO IR and load it.

    The model identified by `model_id` is exported into a temporary
    directory via `main_export`, the `config` is saved alongside it, and
    the exported artifacts are loaded back through `cls._from_pretrained`.

    Args:
        model_id: Hub model id or local path of the source model.
        config: The model's `PretrainedConfig`, saved next to the export.
        use_auth_token / revision / force_download / cache_dir /
        subfolder / local_files_only / trust_remote_code:
            Forwarded to `main_export` for model download/loading.
        task: Export task; defaults to `cls.export_feature` when `None`.
        load_in_8bit: When set (True or False), the export itself is done
            in "fp32" and any 8-bit handling is deferred to
            `_from_pretrained`; when `None`, `main_export` picks the
            compression option (e.g. based on model size).
        load_in_4bit: Accepted for interface compatibility; not used here.
        **kwargs: Forwarded to `cls._from_pretrained`.
    """
    export_dir = TemporaryDirectory()
    export_path = Path(export_dir.name)

    # None lets main_export decide the compression by default; an explicit
    # load_in_8bit value forces an fp32 export here instead.
    if load_in_8bit is None:
        compression_option = None
    else:
        compression_option = "fp32"

    # OVModelForFeatureExtraction works with Transformers type of models,
    # thus even sentence-transformers models are loaded as such.
    main_export(
        model_name_or_path=model_id,
        output=export_path,
        task=task or cls.export_feature,
        subfolder=subfolder,
        revision=revision,
        cache_dir=cache_dir,
        use_auth_token=use_auth_token,
        local_files_only=local_files_only,
        force_download=force_download,
        trust_remote_code=trust_remote_code,
        compression_option=compression_option,
        library_name="transformers",
    )

    config.save_pretrained(export_path)
    # NOTE(review): `export_dir` is not handed to `_from_pretrained`; the
    # temporary directory's lifetime ends with this frame — confirm the
    # model is fully materialized before cleanup.
    return cls._from_pretrained(model_id=export_path, config=config, load_in_8bit=load_in_8bit, **kwargs)


MASKED_LM_EXAMPLE = r"""
Example of masked language modeling using `transformers.pipelines`:
Expand Down

0 comments on commit b2751dc

Please sign in to comment.