
Commit 8405675: fix rotary emb initialization
Parent: 50c417e

2 files changed: +116, -0

optimum/exporters/openvino/model_configs.py
74 additions, 0 deletions
@@ -16,13 +16,15 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 from packaging import version
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available
 
 from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
+    GPTNeoXOnnxConfig,
     LlamaOnnxConfig,
     MistralOnnxConfig,
     MPTOnnxConfig,
@@ -31,6 +33,7 @@
     VaeDecoderOnnxConfig,
     VaeEncoderOnnxConfig,
 )
+from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager
 from optimum.utils import DEFAULT_DUMMY_SHAPES
 from optimum.utils.input_generators import (
@@ -50,6 +53,9 @@
     ChatGLMModelPatcher,
     CodeGenModelPatcher,
     DBRXModelPatcher,
+    FalconModelPatcher,
+    GptNeoxJapaneseModelPatcher,
+    GptNeoxModelPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
     JaisModelPatcher,
@@ -60,6 +66,7 @@
     PersimmonModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
+    RotaryEmbPatcher,
     XverseModelPatcher,
 )
 
@@ -119,6 +126,11 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -128,6 +140,11 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -146,6 +163,11 @@ class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
@@ -469,6 +491,12 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
@@ -508,6 +536,24 @@ def patch_model_for_export(
         return MPTModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "phi",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class PhiOpenVINOConfig(PhiOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "phi3",
     *[
@@ -578,6 +624,11 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
     ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
     DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return FalconModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
 class UNetOpenVINOConfig(UNetOnnxConfig):
@@ -671,6 +722,11 @@ class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager(
     "cohere",
@@ -859,3 +915,21 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return MistralModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "gpt-neox",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTNeoxOpenVINOConfig(GPTNeoXOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptNeoxModelPatcher(self, model, model_kwargs=model_kwargs)
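
With these registrations in place, any conversion that goes through the "openvino" exporter picks the patchers up automatically. A minimal usage sketch (the model id is illustrative; pythia-160m is a GPT-NeoX checkpoint, so it exercises the new GptNeoxModelPatcher):

    # exporting through optimum-intel; during tracing the patcher re-caches
    # the rotary sin/cos tables in float32 before the graph is captured
    from optimum.intel import OVModelForCausalLM

    ov_model = OVModelForCausalLM.from_pretrained("EleutherAI/pythia-160m", export=True)
    ov_model.save_pretrained("pythia-160m-openvino")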

optimum/exporters/openvino/model_patcher.py
42 additions, 0 deletions
@@ -101,6 +101,15 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
+# initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
+# reinitialize them to save in float32 before export
+def _reinitialize_cos_sin_cached_fp32(rotary_emb):
+    if rotary_emb.cos_cached.dtype != torch.float32:
+        rotary_emb._set_cos_sin_cache(
+            seq_len=rotary_emb.max_position_embeddings, device=rotary_emb.inv_freq.device, dtype=torch.float32
+        )
+
+
 def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     """ """
     batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -149,6 +158,7 @@ def __enter__(self):
             layer.block_sparse_moe.forward = types.MethodType(
                 _mixtral_sparse_moe_block_forward, layer.block_sparse_moe
             )
+            _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -687,6 +697,9 @@ def __enter__(self):
             # mistral has some accuracy issues with bf16 with transformers >= 4.42
             # prefill rotary emb sin/cos for avoid this issue
             register_sin_cos_buffer(self._model)
+        else:
+            for layer in self._model.model.layers:
+                _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -2094,6 +2107,7 @@ def __enter__(self):
             orig_self_attn_fwd = layer.self_attn.forward
             layer.self_attn.forward = types.MethodType(_persimmon_self_attn_sdpa_forward, layer.self_attn)
             layer.self_attn._orig_forward = orig_self_attn_fwd
+            _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -2221,3 +2235,31 @@ def __exit__(self, exc_type, exc_value, traceback):
             if hasattr(layer.attn, "_orig_attn"):
                 layer.attn._attn = layer.attn._orig_attn
             layer.attn.forward = layer.attn._orig_forward
+
+
+class RotaryEmbPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
+
+
+class FalconModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.transformer.h:
+            _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
+
+
+class GptNeoxModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.gpt_neox.layers:
+            _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+
+
+class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.gpt_neox_japanese.layers:
+            _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
