skip assisted decoding unit test for models using paged attention
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
kaixuanliu committed Nov 13, 2024
1 parent 76d32be commit 068dde6
Showing 1 changed file with 4 additions and 4 deletions.
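For context, the skipped test exercises assisted (speculative) decoding, where a draft model proposes tokens that the main model then verifies. Below is a minimal sketch of what such a test typically runs through the transformers generate API; the model id is only a placeholder for illustration, and the assistant here is the same model rather than a smaller draft model. Per the updated comment in the diff, IPEX-patched (paged-attention) models are now skipped because assisted decoding does not support their cache handling.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder model id for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
assistant = AutoModelForCausalLM.from_pretrained(model_id)  # normally a smaller draft model

inputs = tokenizer("This is a sample input", return_tensors="pt")
# Passing assistant_model switches generate() into assisted decoding
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))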
8 changes: 4 additions & 4 deletions tests/ipex/test_modeling.py
@@ -208,7 +208,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
"mpt",
"opt",
)
IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "distilgpt2", "falcon")
IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama", "llama2", "distilgpt2", "falcon", "gpt2")
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.0

@@ -264,8 +264,8 @@ def test_pipeline(self, model_arch):

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_assisted_decoding(self, model_arch):
- # Patched models are not support assisted decoding if ipex < 2.5.
- if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"):
+ # assisted decoding does not support static cache yet
+ if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
return
model_id = MODEL_NAMES[model_arch]
tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -440,7 +440,7 @@ def test_compare_to_transformers(self, model_arch):
self.assertIn("logits", outputs)
# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
- self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
+ self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-4))
self.assertTrue(torch.allclose(init_model_outputs.logits, transformers_outputs.logits, atol=1e-4))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
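The last hunk relaxes the reload-consistency check from exact element-wise equality to a tolerance-based comparison, matching the check already used against the transformers reference. A minimal illustration of the difference, using made-up tensors:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6  # tiny numerical noise, e.g. from a different kernel path

print(torch.equal(a, b))                # False: requires exact element-wise equality
print(torch.allclose(a, b, atol=1e-4))  # True: differences within tolerance pass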
