Support transformers 4.47 #1088

Merged · 11 commits · Jan 6, 2025
18 changes: 16 additions & 2 deletions optimum/exporters/openvino/model_patcher.py
Collaborator:

Looks like the mismatch for qwen2_vl is coming from https://github.com/huggingface/transformers/pull/33487/files (a mismatch in the position_ids when they are not given at inference).
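For context on that comment: when position_ids is omitted, text-only decoders typically fall back to a simple arange over the sequence, while qwen2_vl derives multimodal rope indices, so the two paths can disagree after the linked change. Below is a minimal sketch of the plain-text fallback; the helper name and shapes are illustrative, not transformers' actual code.

```python
import torch

def default_position_ids(input_ids: torch.Tensor, past_length: int = 0) -> torch.Tensor:
    # Illustrative text-only fallback: positions are a plain arange over the
    # sequence, offset by the number of cached tokens, broadcast over the batch.
    batch_size, seq_length = input_ids.shape
    positions = torch.arange(past_length, past_length + seq_length, dtype=torch.long)
    return positions.unsqueeze(0).expand(batch_size, -1)

# Example: a batch of 2 sequences of length 5 -> positions 0..4 per row.
print(default_position_ids(torch.zeros(2, 5, dtype=torch.long)))
```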

@@ -2712,7 +2712,14 @@ def patched_forward(*args, **kwargs):
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
 
             sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
@@ -2725,7 +2732,14 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
+
+        if is_transformers_version(">=", "4.47.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
+
+            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
+            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
+
+        elif is_transformers_version(">=", "4.45.0"):
             for layer in self._model.model.layers:
                 if hasattr(layer.self_attn, "_orig_forward"):
                     layer.self_attn.forward = layer.self_attn._orig_forward
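For readers following along: in transformers 4.47 Gemma2 looks up its attention implementation in the module-level GEMMA2_ATTENTION_FUNCTION dict, so the patch above stashes the real "eager" entry under "original_eager", points "eager" at the SDPA implementation for export, and restores the original on exit. Here is a self-contained sketch of that save/swap/restore pattern; the registry and toy functions are stand-ins, not the transformers objects.

```python
from contextlib import contextmanager

# Stand-in for a module-level attention registry like GEMMA2_ATTENTION_FUNCTION.
ATTENTION_REGISTRY = {
    "eager": lambda q, k, v: "eager-attention",
    "sdpa": lambda q, k, v: "sdpa-attention",
}

@contextmanager
def eager_routed_to_sdpa(registry):
    registry["original_eager"] = registry["eager"]  # stash the real entry
    registry["eager"] = registry["sdpa"]            # route "eager" to SDPA
    try:
        yield registry
    finally:
        # Restore the original entry and drop the stash, even on error.
        registry["eager"] = registry.pop("original_eager")

with eager_routed_to_sdpa(ATTENTION_REGISTRY) as reg:
    assert reg["eager"](None, None, None) == "sdpa-attention"
assert ATTENTION_REGISTRY["eager"](None, None, None) == "eager-attention"
```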
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
 INSTALL_REQUIRE = [
     "torch>=1.11",
     "optimum@git+https://github.com/huggingface/optimum.git",
-    "transformers>=4.36,<4.47",
+    "transformers>=4.36,<4.48",
     "datasets>=1.4.0",
     "sentencepiece",
     "setuptools",