Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply SDPA for MPT and InternLM #676

Merged
merged 12 commits into from
Apr 30, 2024
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def ts_patched_forward(*args, **kwargs):

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())
Expand Down Expand Up @@ -376,7 +377,6 @@ def ts_patched_forward(*args, **kwargs):
ov_config=ov_config,
)

sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
if not ordered_dummy_inputs:
ordered_dummy_inputs = dummy_inputs
Expand All @@ -392,7 +392,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_tensor().set_names({input_name})
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs[input_name]
dims = inputs.get(input_name, [])
for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
Expand Down
18 changes: 18 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
FalconOnnxConfig,
GemmaOnnxConfig,
LlamaOnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
UNetOnnxConfig,
VaeDecoderOnnxConfig,
Expand All @@ -43,8 +44,10 @@
BaichuanModelPatcher,
ChatGLMModelPatcher,
GemmaModelPatcher,
InternLMPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
)
Expand Down Expand Up @@ -439,6 +442,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

# Export hook for the InternLM2 config: build the model patcher used for this export.
# Args:
#   model: the model being exported (annotated as a PyTorch or TF pretrained model).
#   model_kwargs: optional extra keyword arguments, forwarded to the patcher unchanged.
# Returns an InternLMPatcher constructed from this config, the model and model_kwargs.
# NOTE(review): per the PR title the patcher presumably switches attention to SDPA;
# confirm against the InternLMPatcher implementation, which is not visible here.
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return InternLMPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand All @@ -455,6 +463,16 @@ class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


# Register an OpenVINO export config for the "mpt" model type (transformers library)
# covering the text-generation, text-generation-with-past and text-classification tasks.
@register_in_tasks_manager(
"mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers"
)
# Subclasses MPTOnnxConfig; the only override is the patcher used during export.
class MPTOpenVINOConfig(MPTOnnxConfig):
# Export hook: return an MPTModelPatcher built from this config, the model being
# exported and the optional model_kwargs (forwarded unchanged).
# NOTE(review): per the PR title the patcher presumably switches attention to SDPA;
# confirm against the MPTModelPatcher implementation, which is not visible here.
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MPTModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"phi3",
*[
Expand Down
Loading
Loading