Skip to content

Commit

Permalink
Apply sdpa for mpt and internlm (#676)
Browse files Browse the repository at this point in the history
* apply sdpa for mpt and internlm

* fix baichuan-13b

* fix accuracy

* small refactoring

* add test for baichuan 13b

* add support output_attentions

* code style
  • Loading branch information
eaidova authored Apr 30, 2024
1 parent b017856 commit e1b6a59
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 3 deletions.
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def ts_patched_forward(*args, **kwargs):

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())
Expand Down Expand Up @@ -387,7 +388,6 @@ def ts_patched_forward(*args, **kwargs):
ov_config=ov_config,
)

sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
if not ordered_dummy_inputs:
ordered_dummy_inputs = dummy_inputs
Expand All @@ -403,7 +403,7 @@ def ts_patched_forward(*args, **kwargs):
inp_tensor.get_tensor().set_names({input_name})
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs[input_name]
dims = inputs.get(input_name, [])
for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
Expand Down
18 changes: 18 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
FalconOnnxConfig,
GemmaOnnxConfig,
LlamaOnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
UNetOnnxConfig,
VaeDecoderOnnxConfig,
Expand All @@ -43,8 +44,10 @@
BaichuanModelPatcher,
ChatGLMModelPatcher,
GemmaModelPatcher,
InternLMPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
)
Expand Down Expand Up @@ -439,6 +442,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

def patch_model_for_export(
    self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
    """Wrap *model* in an InternLM-specific patcher for OpenVINO export.

    The patcher applies the export-time forward modifications (e.g. the
    SDPA-based attention path) before tracing.
    """
    patcher = InternLMPatcher(self, model, model_kwargs=model_kwargs)
    return patcher


@register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand All @@ -455,6 +463,16 @@ class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager(
    "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers"
)
class MPTOpenVINOConfig(MPTOnnxConfig):
    """OpenVINO export config for MPT models.

    Inherits the ONNX-side input/output layout from ``MPTOnnxConfig`` and
    only overrides the model patcher so the OpenVINO-specific forward
    rewrites are applied during export.
    """

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Return an MPT-specific patcher wrapping *model* for tracing."""
        patcher = MPTModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher


@register_in_tasks_manager(
"phi3",
*[
Expand Down
Loading

0 comments on commit e1b6a59

Please sign in to comment.