From ce7d8869944554b058771943b182e665c29e2ae9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 26 Nov 2024 14:39:20 +0000 Subject: [PATCH] simplify forward and save pretrained since no jit support --- optimum/intel/ipex/modeling_base.py | 83 ++--------------------------- 1 file changed, 3 insertions(+), 80 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 19971dacf8..e9e5b98527 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -233,39 +233,10 @@ def _from_pretrained( return cls(model, config=config, export=True, **kwargs) def _save_pretrained(self, save_directory: Union[str, Path]): - if getattr(self.config, "torchscript", None): - output_path = os.path.join(save_directory, WEIGHTS_NAME) - torch.jit.save(self.model, output_path) - else: - logger.warning("The module is not a torchscript model, will be treated as a transformers model.") - self.model.save_pretrained(save_directory, safe_serialization=False) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids + self.model.save_pretrained(save_directory, safe_serialization=False) - if "position_ids" in self.input_names: - inputs["position_ids"] = position_ids - - outputs = self._call_model(**inputs) - if isinstance(outputs, dict): - model_output = ModelOutput(**outputs) - else: - model_output = ModelOutput() - model_output[self.output_name] = outputs[0] - return model_output + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) def eval(self): self.model.eval() @@ -336,64 +307,16 @@ class IPEXModelForImageClassification(IPEXModel): auto_model_class = AutoModelForImageClassification export_feature = "image-classification" - def forward( - self, - pixel_values: torch.Tensor, - **kwargs, - ): - inputs = { - "pixel_values": pixel_values, - } - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForAudioClassification(IPEXModel): auto_model_class = AutoModelForAudioClassification export_feature = "audio-classification" - def forward( - self, - input_values: torch.Tensor, - attention_mask: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_values": input_values, - } - - if "attention_mask" in self.input_names: - inputs["attention_mask"] = attention_mask - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForQuestionAnswering(IPEXModel): auto_model_class = AutoModelForQuestionAnswering export_feature = "question-answering" - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids - - outputs = self._call_model(**inputs) - start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] - end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] - return ModelOutput(start_logits=start_logits, end_logits=end_logits) - class IPEXModelForCausalLM(IPEXModel, GenerationMixin): auto_model_class = AutoModelForCausalLM