Skip to content

Commit

Permalink
add output name to ipex model
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Mar 8, 2024
1 parent ac809f8 commit b9c9736
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class IPEXModel(OptimizedModel):
export_feature = "feature-extraction"
base_model_prefix = "ipex_model"
main_input_name = "input_ids"
output_name = "last_hidden_state"

def __init__(
self,
Expand Down Expand Up @@ -193,7 +194,12 @@ def forward(
inputs["token_type_ids"] = token_type_ids

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0])
if isinstance(outputs, dict):
model_output = ModelOutput(**outputs)
else:
model_output = ModelOutput()
model_output[self.output_name] = outputs[0]
return model_output

def eval(self):
self.model.eval()
Expand Down Expand Up @@ -235,16 +241,19 @@ def _init_warmup(self):
class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
export_feature = "text-classification"
output_name = "logits"


class IPEXModelForTokenClassification(IPEXModel):
auto_model_class = AutoModelForTokenClassification
export_feature = "token-classification"
output_name = "logits"


class IPEXModelForMaskedLM(IPEXModel):
auto_model_class = AutoModelForMaskedLM
export_feature = "fill-mask"
output_name = "logits"


class IPEXModelForImageClassification(IPEXModel):
Expand Down

0 comments on commit b9c9736

Please sign in to comment.