Skip to content

Commit

Permalink
update testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Mar 7, 2024
1 parent 1b8d76a commit ac809f8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def forward(
inputs["token_type_ids"] = token_type_ids

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0])

def eval(self):
self.model.eval()
Expand Down Expand Up @@ -282,7 +282,7 @@ def forward(
inputs["attention_mask"] = attention_mask

outputs = self._call_model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0])
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


class IPEXModelForQuestionAnswering(IPEXModel):
Expand Down
7 changes: 4 additions & 3 deletions tests/ipex/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def test_compare_to_transformers(self, model_arch):
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)
outputs = ipex_model(**tokens)
self.assertTrue(
torch.allclose(outputs["last_hidden_state"], transformers_outputs["last_hidden_state"], atol=1e-4)
)
# Compare tensor outputs
for output_name in {"logits", "last_hidden_state"}:
if output_name in transformers_outputs:
self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
Expand Down

0 comments on commit ac809f8

Please sign in to comment.