Skip to content

Commit

Permalink
simplify forward and save pretrained since no jit support
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Nov 26, 2024
1 parent 8a8e7e3 commit ce7d886
Showing 1 changed file with 3 additions and 80 deletions.
83 changes: 3 additions & 80 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,39 +233,10 @@ def _from_pretrained(
return cls(model, config=config, export=True, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
if getattr(self.config, "torchscript", None):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
torch.jit.save(self.model, output_path)
else:
logger.warning("The module is not a torchscript model, will be treated as a transformers model.")
self.model.save_pretrained(save_directory, safe_serialization=False)

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
position_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids
self.model.save_pretrained(save_directory, safe_serialization=False)

if "position_ids" in self.input_names:
inputs["position_ids"] = position_ids

outputs = self._call_model(**inputs)
if isinstance(outputs, dict):
model_output = ModelOutput(**outputs)
else:
model_output = ModelOutput()
model_output[self.output_name] = outputs[0]
return model_output
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def eval(self):
self.model.eval()
Expand Down Expand Up @@ -336,64 +307,16 @@ class IPEXModelForImageClassification(IPEXModel):
auto_model_class = AutoModelForImageClassification
export_feature = "image-classification"

def forward(
self,
pixel_values: torch.Tensor,
**kwargs,
):
inputs = {
"pixel_values": pixel_values,
}

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForAudioClassification(IPEXModel):
auto_model_class = AutoModelForAudioClassification
export_feature = "audio-classification"

def forward(
self,
input_values: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_values": input_values,
}

if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
auto_model_class = AutoModelForQuestionAnswering
export_feature = "question-answering"

def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor = None,
**kwargs,
):
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

outputs = self._call_model(**inputs)
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
return ModelOutput(start_logits=start_logits, end_logits=end_logits)


class IPEXModelForCausalLM(IPEXModel, GenerationMixin):
auto_model_class = AutoModelForCausalLM
Expand Down

0 comments on commit ce7d886

Please sign in to comment.