16
16
from typing import TYPE_CHECKING , Any , Dict , Optional , Tuple , Union
17
17
18
18
from packaging import version
19
+ from transformers import PreTrainedModel , TFPreTrainedModel
19
20
from transformers .utils import is_tf_available
20
21
21
22
from optimum .exporters .onnx .config import TextDecoderOnnxConfig , TextDecoderWithPositionIdsOnnxConfig
22
23
from optimum .exporters .onnx .model_configs import (
23
24
CodeGenOnnxConfig ,
24
25
FalconOnnxConfig ,
25
26
GemmaOnnxConfig ,
27
+ GPTNeoXOnnxConfig ,
26
28
LlamaOnnxConfig ,
27
29
MistralOnnxConfig ,
28
30
MPTOnnxConfig ,
31
33
VaeDecoderOnnxConfig ,
32
34
VaeEncoderOnnxConfig ,
33
35
)
36
+ from optimum .exporters .onnx .model_patcher import ModelPatcher
34
37
from optimum .exporters .tasks import TasksManager
35
38
from optimum .utils import DEFAULT_DUMMY_SHAPES
36
39
from optimum .utils .input_generators import (
50
53
ChatGLMModelPatcher ,
51
54
CodeGenModelPatcher ,
52
55
DBRXModelPatcher ,
56
+ FalconModelPatcher ,
57
+ GptNeoxJapaneseModelPatcher ,
58
+ GptNeoxModelPatcher ,
53
59
InternLM2Patcher ,
54
60
InternLMModelPatcher ,
55
61
JaisModelPatcher ,
60
66
PersimmonModelPatcher ,
61
67
Phi3ModelPatcher ,
62
68
QwenModelPatcher ,
69
+ RotaryEmbPatcher ,
63
70
XverseModelPatcher ,
64
71
)
65
72
@@ -119,6 +126,11 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
119
126
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
120
127
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
121
128
129
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``RotaryEmbPatcher`` built from this config, ``model`` and ``model_kwargs``.
    """
    patcher = RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
133
+
122
134
123
135
@register_in_tasks_manager ("qwen2-moe" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
124
136
class Qwen2MoEOpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
@@ -128,6 +140,11 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
128
140
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
129
141
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
130
142
143
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``RotaryEmbPatcher`` built from this config, ``model`` and ``model_kwargs``.
    """
    patcher = RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
147
+
131
148
132
149
@register_in_tasks_manager ("minicpm" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
133
150
class MiniCPMOpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
@@ -146,6 +163,11 @@ class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
146
163
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
147
164
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
148
165
166
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``RotaryEmbPatcher`` built from this config, ``model`` and ``model_kwargs``.
    """
    patcher = RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
170
+
149
171
150
172
class ChatGLM2DummyPastKeyValuesGenerator (DummyPastKeyValuesGenerator ):
151
173
def __init__ (
@@ -469,6 +491,12 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
469
491
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
470
492
471
493
494
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``RotaryEmbPatcher`` built from this config, ``model`` and ``model_kwargs``.
    """
    patcher = RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
498
+
499
+
472
500
@register_in_tasks_manager ("internlm2" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
473
501
class InternLM2OpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
474
502
DEFAULT_ONNX_OPSET = 14
@@ -508,6 +536,24 @@ def patch_model_for_export(
508
536
return MPTModelPatcher (self , model , model_kwargs = model_kwargs )
509
537
510
538
539
@register_in_tasks_manager(
    "phi",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class PhiOpenVINOConfig(PhiOnnxConfig):
    """Export config registered for the ``"phi"`` model type (``transformers`` library).

    Extends ``PhiOnnxConfig`` and only overrides the patcher used at export time.
    """

    def patch_model_for_export(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ModelPatcher":
        """Create the model patcher used with this export config.

        Args:
            model: The model handed to the patcher.
            model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

        Returns:
            A ``RotaryEmbPatcher`` built from this config, ``model`` and ``model_kwargs``.
        """
        patcher = RotaryEmbPatcher(self, model, model_kwargs=model_kwargs)
        return patcher
555
+
556
+
511
557
@register_in_tasks_manager (
512
558
"phi3" ,
513
559
* [
@@ -578,6 +624,11 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
578
624
) + TextDecoderOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES
579
625
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
580
626
627
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``FalconModelPatcher`` built from this config, ``model`` and ``model_kwargs``.
    """
    patcher = FalconModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
631
+
581
632
582
633
@register_in_tasks_manager ("unet" , * ["semantic-segmentation" ], library_name = "diffusers" )
583
634
class UNetOpenVINOConfig (UNetOnnxConfig ):
@@ -671,6 +722,11 @@ class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig):
671
722
DEFAULT_ONNX_OPSET = 13
672
723
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
673
724
725
def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """Create the model patcher used with this export config.

    Args:
        model: The model handed to the patcher.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A ``GptNeoxJapaneseModelPatcher`` built from this config, ``model`` and
        ``model_kwargs``.
    """
    patcher = GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher
729
+
674
730
675
731
@register_in_tasks_manager (
676
732
"cohere" ,
@@ -859,3 +915,21 @@ def patch_model_for_export(
859
915
self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
860
916
) -> "ModelPatcher" :
861
917
return MistralModelPatcher (self , model , model_kwargs = model_kwargs )
918
+
919
+
920
@register_in_tasks_manager(
    "gpt-neox",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class GPTNeoxOpenVINOConfig(GPTNeoXOnnxConfig):
    """Export config registered for the ``"gpt-neox"`` model type (``transformers`` library).

    Extends ``GPTNeoXOnnxConfig`` and only overrides the patcher used at export time.
    """

    def patch_model_for_export(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ModelPatcher":
        """Create the model patcher used with this export config.

        Args:
            model: The model handed to the patcher.
            model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

        Returns:
            A ``GptNeoxModelPatcher`` built from this config, ``model`` and ``model_kwargs``.
        """
        patcher = GptNeoxModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher
0 commit comments