From b9c9736ae2d64fdbe9c5cf8d692a003a3fa8ec42 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 8 Mar 2024 05:29:33 -0500 Subject: [PATCH] add output name to ipex model --- optimum/intel/ipex/modeling_base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 608afa0805..9928977ead 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -58,6 +58,7 @@ class IPEXModel(OptimizedModel): export_feature = "feature-extraction" base_model_prefix = "ipex_model" main_input_name = "input_ids" + output_name = "last_hidden_state" def __init__( self, @@ -193,7 +194,12 @@ def forward( inputs["token_type_ids"] = token_type_ids outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0]) + if isinstance(outputs, dict): + model_output = ModelOutput(**outputs) + else: + model_output = ModelOutput() + model_output[self.output_name] = outputs[0] + return model_output def eval(self): self.model.eval() @@ -235,16 +241,19 @@ def _init_warmup(self): class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification export_feature = "text-classification" + output_name = "logits" class IPEXModelForTokenClassification(IPEXModel): auto_model_class = AutoModelForTokenClassification export_feature = "token-classification" + output_name = "logits" class IPEXModelForMaskedLM(IPEXModel): auto_model_class = AutoModelForMaskedLM export_feature = "fill-mask" + output_name = "logits" class IPEXModelForImageClassification(IPEXModel):