Skip to content

Commit

Permalink
fix conversion for text embeddings for fp16 models
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 24, 2024
1 parent 86598a6 commit 2d0779c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def ts_patched_forward(*args, **kwargs):
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

# the frontend may raise confusing warnings about modules already being patched if the model is split into several parts
logging.disable(logging.WARNING)
__make_16bit_traceable(model)
logging.disable(logging.NOTSET)
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
ov_model = convert_model(
Expand Down
6 changes: 6 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
IBertModelPatcher,
InputEmbeddingPatcher,
InternLM2Patcher,
InternLMModelPatcher,
InternVLChatImageEmbeddingModelPatcher,
Expand Down Expand Up @@ -1261,6 +1262,11 @@ def rename_ambiguous_inputs(self, inputs):
model_inputs["input"] = inputs["input_ids"]
return model_inputs

def patch_model_for_export(
    self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
    """Build the patcher used while exporting this model.

    Args:
        model: The model that is about to be exported.
        model_kwargs: Optional extra keyword arguments forwarded to the patcher.

    Returns:
        An `InputEmbeddingPatcher` constructed from this config, `model` and
        `model_kwargs`.
    """
    patcher = InputEmbeddingPatcher(config=self, model=model, model_kwargs=model_kwargs)
    return patcher


class LlavaConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
Expand Down
22 changes: 22 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2743,3 +2743,25 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.pos_embed, "_orig_forward"):
self._model.pos_embed.forward = self._model.pos_embed._orig_forward


class InputEmbeddingPatcher(ModelPatcher):
    """Model patcher that pins the model's ``forward`` to a single ``input`` argument.

    Making a model 16-bit traceable overrides the embedding module's input
    signature. To prevent that issue, ``model.forward`` is replaced with a thin
    wrapper taking exactly one explicit ``input`` argument *before* the base
    patcher is initialised, and the original ``forward`` is put back in
    ``__exit__``.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Dict[str, Any],
    ):
        """Wrap ``model.forward`` and initialise the base patcher.

        Args:
            config: Export config handed through to ``ModelPatcher``.
            model: Model whose ``forward`` is wrapped for the export.
            model_kwargs: Extra keyword arguments handed through to ``ModelPatcher``.
        """
        # Keep the original forward in a closure instead of stashing it on the
        # model: nothing is left behind on the model afterwards, and building a
        # second patcher for the same model cannot overwrite the saved original
        # with the wrapper (which would make the wrapper call itself forever).
        orig_forward = model.forward

        # The parameter must be called `input` so the wrapper exposes the
        # explicit single-argument signature the export relies on.
        def forward(module, input):
            return orig_forward(input)

        model.forward = types.MethodType(forward, model)

        try:
            super().__init__(config, model, model_kwargs)
        except Exception:
            # Do not leave the model patched if the base patcher cannot be built.
            model.forward = orig_forward
            raise

        self._embedding_orig_forward = orig_forward

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the base clean-up, then restore the model's original ``forward``."""
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._embedding_orig_forward

0 comments on commit 2d0779c

Please sign in to comment.