Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jun 20, 2024
1 parent 4875dfd commit 0772eb1
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,14 +971,21 @@ def test_beam_search(self, model_arch):
ov_model_stateless.config.eos_token_id = None
transformers_model.config.eos_token_id = None

for idx, gen_config in enumerate(gen_configs):
for gen_config in gen_configs:
if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
continue

transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs), f"generation config : {idx}")
self.assertTrue(
torch.equal(ov_stateful_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model_stateful output {ov_stateful_outputs}",
)
ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs), f"generation config : {idx}")
self.assertTrue(
torch.equal(ov_stateless_outputs, transformers_outputs),
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model_stateless output {ov_stateless_outputs}",
)


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
Expand Down

0 comments on commit 0772eb1

Please sign in to comment.