Skip to content

Commit

Permalink
udpate test
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jul 4, 2024
1 parent 4ee9c99 commit d8fed91
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_openvino.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.12"]
transformers-version: ["4.36.0", "4.41.*","4.42.*"]
transformers-version: ["4.36.0","4.42.*"]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
Expand Down
5 changes: 0 additions & 5 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,11 +748,6 @@ def test_compare_to_transformers(self, model_arch):
)

ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)

# TODO: update _update_model_kwargs_for_generation so that it's compatibile with transformers >= v4.42.0
if model_arch not in ["chatglm", "glm4"] and is_transformers_version(">=", "4.42.0"):
return

transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
self.assertTrue(torch.allclose(ov_outputs, transformers_outputs))

Expand Down
9 changes: 2 additions & 7 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@

class OVQuantizerTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_TORCH_MODEL = (
(OVModelForSequenceClassification, "bert", 22, 35),
(OVModelForCausalLM, "gpt2", 21, 3),
(OVModelForSequenceClassification, "bert", 32 if is_transformers_version("<", "4.41.0") else 22, 35),
(OVModelForCausalLM, "gpt2", 41 if is_transformers_version("<", "4.42.0") else 21, 3),
)
SUPPORTED_ARCHITECTURES_OV_MODEL = (
(OVModelForSequenceClassification, "bert", 32, 35),
Expand All @@ -90,11 +90,6 @@ def test_automodel_static_quantization(self, model_cls, model_name, expected_fak
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
file_name = "openvino_quantized_model.xml"

if is_transformers_version("<", "4.41.0") and model_name == "bert":
expected_fake_quantize = 32
if is_transformers_version("<", "4.42.0") and model_name == "gpt2":
expected_fake_quantize = 41

def preprocess_function(examples, tokenizer):
return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)

Expand Down

0 comments on commit d8fed91

Please sign in to comment.