diff --git a/.github/workflows/test_openvino_basic.yml b/.github/workflows/test_openvino_basic.yml
index 432819dfdd..141b94425e 100644
--- a/.github/workflows/test_openvino_basic.yml
+++ b/.github/workflows/test_openvino_basic.yml
@@ -41,7 +41,7 @@ jobs:
           # optimum or transformers to a specific version
           # Install PyTorch CPU to prevent unnecessary downloading/installing of CUDA packages
           pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install .[tests] openvino onnxruntime ${{ matrix.optimum}}
+          pip install .[tests] openvino ${{ matrix.optimum}}
 
       - name: Pip freeze
         run: pip freeze
@@ -52,6 +52,7 @@ jobs:
 
       - name: Slow tests
         run: |
-          pytest tests/openvino/test_modeling.py -s -m "run_slow" --durations=0
+          pip install nncf
+          pytest tests/openvino -s -m "run_slow" --durations=0
         env:
           RUN_SLOW: 1
diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 430101db18..17bcea965b 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -20,7 +20,6 @@
 from typing import TYPE_CHECKING, Optional
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from transformers.utils.quantization_config import QuantizationMethod
 
 from ...exporters import TasksManager
 from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
@@ -289,7 +288,7 @@ def _get_default_int4_config(model_id_or_path, library_name):
                 "all_layers": None if is_int8 else self.args.all_layers,
                 "dataset": self.args.dataset,
                 "num_samples": self.args.num_samples,
-                "quant_method": QuantizationMethod.AWQ if self.args.awq else None,
+                "quant_method": "awq" if self.args.awq else "default",
                 "sensitivity_metric": self.args.sensitivity_metric,
                 "scale_estimation": self.args.scale_estimation,
             }
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index ec1d4805d0..75e6863169 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -167,24 +167,27 @@ def __init__(
         )
         self.multi_query_group_num = normalized_config.multi_query_group_num
         self.head_dim = normalized_config.kv_channels
+        self.standart_cache_layout = hasattr(normalized_config, "rope_ratio")
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        past_key_shape = (
-            self.sequence_length,
-            self.batch_size,
-            self.multi_query_group_num,
-            self.head_dim,
-        )
-        past_value_shape = (
-            self.sequence_length,
-            self.batch_size,
-            self.multi_query_group_num,
-            self.head_dim,
-        )
+        if not self.standart_cache_layout:
+            pkv_shape = (
+                self.sequence_length,
+                self.batch_size,
+                self.multi_query_group_num,
+                self.head_dim,
+            )
+        else:
+            pkv_shape = (
+                self.batch_size,
+                self.multi_query_group_num,
+                self.sequence_length,
+                self.head_dim,
+            )
         return [
             (
-                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
-                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(pkv_shape, framework=framework, dtype=float_dtype),
             )
             for _ in range(self.num_layers)
         ]
@@ -229,7 +232,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             and "attention_mask" in dummy_inputs
         ):
             # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2
-            past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0]
+            seq_len_dim = 0 if not hasattr(self._normalized_config, "rope_ratio") else -2
+            past_present_length = (
+                dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[seq_len_dim]
+            )
 
             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
@@ -260,9 +266,18 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             decoder_sequence_name = "past_sequence_length + present_lenght"
             name = "present"
 
+        is_v4 = hasattr(self._normalized_config, "rope_ratio")
         for i in range(self._normalized_config.num_layers):
-            inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name}
-            inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.key"] = (
+                {1: "batch_size", 0: decoder_sequence_name}
+                if not is_v4
+                else {0: "batch_size", 2: decoder_sequence_name}
+            )
+            inputs_or_outputs[f"{name}.{i}.value"] = (
+                {1: "batch_size", 0: decoder_sequence_name}
+                if not is_v4
+                else {0: "batch_size", 2: decoder_sequence_name}
+            )
 
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 7a98f13e1c..629f47aa72 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -190,7 +190,7 @@ def _chatglm_transformer_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embedding(input_ids)
 
-    if self.pre_seq_len is not None:
+    if getattr(self, "pre_seq_len", None) is not None:
         if past_key_values is None:
             past_key_values = self.get_prompt(
                 batch_size=batch_size,
@@ -285,6 +285,17 @@ def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer,
     return context_layer
 
 
+def _glm4_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
+    attention_mask = ~attention_mask
+    context_layer = torch.nn.functional.scaled_dot_product_attention(
+        query_layer, key_layer, value_layer, attention_mask.to(torch.float32)
+    )
+    context_layer = context_layer.transpose(1, 2).contiguous()
+    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+    context_layer = context_layer.reshape(*new_context_layer_shape)
+    return context_layer
+
+
 class ChatGLMModelPatcher(DecoderModelPatcher):
     def __init__(
         self,
@@ -293,21 +304,25 @@ def __init__(
         model_kwargs: Dict[str, Any],
     ):
         super().__init__(config, model, model_kwargs)
-
-        self.original_chatglm_transformer_forward = model.transformer.forward
+        self.is_v4 = hasattr(self._model.config, "rope_ratio")
 
     def __enter__(self):
         super().__enter__()
-        self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
+
+        if not self.is_v4:
+            self._model.transformer._orig_forward = self._model.transformer.forward
+            self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer)
 
         for block in self._model.transformer.encoder.layers:
             block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward
             block.self_attention.core_attention.forward = types.MethodType(
-                _chatglm2_core_attention_forward, block.self_attention.core_attention
+                _chatglm2_core_attention_forward if not self.is_v4 else _glm4_core_attention_forward,
+                block.self_attention.core_attention,
             )
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        self._model.transformer.forward = self.original_chatglm_transformer_forward
+        if hasattr(self._model.transformer, "_orig_forward"):
+            self._model.transformer.forward = self._model.transformer._orig_forward
         for block in self._model.transformer.encoder.layers:
             block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward
diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py
index c90c2211ed..8ca42b67aa 100644
--- a/optimum/exporters/openvino/stateful.py
+++ b/optimum/exporters/openvino/stateful.py
@@ -213,7 +213,7 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
 
     # By default, batch is the 0-th but chatglm uses 1-st dimension as batch
    # TODO: Deduce from a model via ordinal reshape (?) and topology
-    batch_dim = 1 if config.model_type == "chatglm" else 0
+    batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0
     fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
 
     num_attention_heads = config.num_attention_heads if config.model_type == "bloom" else 1
diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py
index ab54e257c3..30e550b54a 100644
--- a/optimum/intel/openvino/configuration.py
+++ b/optimum/intel/openvino/configuration.py
@@ -20,7 +20,7 @@
 
 import torch
 from transformers import PretrainedConfig
-from transformers.utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
+from transformers.utils.quantization_config import QuantizationConfigMixin
 
 from optimum.configuration_utils import BaseConfig
 
@@ -78,6 +78,7 @@
 class OVQuantizationMethod(str, Enum):
     DEFAULT = "default"
     HYBRID = "hybrid"
+    AWQ = "awq"
 
 
 @dataclass
@@ -171,7 +172,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
             entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
         num_samples (`int`, *optional*):
             The maximum number of samples composing the calibration dataset.
-        quant_method (`str`, defaults of OVQuantizationMethod.DEFAULT):
+        quant_method (`str or OVQuantizationMethod`, defaults of OVQuantizationMethod.DEFAULT):
             Weight compression method to apply. Possible options:
                 - "default": default weight quantization will be applied.
                 - "awq": compressed weights will be computed according to the Activation-Aware-Quantization (AWQ)
@@ -199,7 +200,7 @@ def __init__(
         sensitivity_metric: Optional[str] = None,
         ignored_scope: Optional[dict] = None,
         num_samples: Optional[int] = None,
-        quant_method: Union[QuantizationMethod, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT,
+        quant_method: Union[str, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT,
         scale_estimation: bool = None,
         **kwargs,
     ):
@@ -210,7 +211,7 @@ def __init__(
         self.ratio = ratio
         self.all_layers = all_layers
         self.sensitivity_metric = sensitivity_metric
-        self.quant_method = quant_method
+        self.quant_method = OVQuantizationMethod(quant_method) if isinstance(quant_method, str) else quant_method
         self.scale_estimation = scale_estimation
         self.post_init()
 
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 1669cb8143..067b3e5d5d 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -328,9 +328,9 @@ def _reshape(
                 shapes[inputs][0] = -1
             input_name = inputs.get_any_name()
             if input_name.startswith("past_key_values"):
-                if (
-                    len(inputs.partial_shape) == 3 and input_name.endswith("value")
-                ) or self.config.model_type == "chatglm":
+                if (len(inputs.partial_shape) == 3 and input_name.endswith("value")) or (
+                    self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio")
+                ):
                     shapes[inputs][1] = -1
                 else:
                     shapes[inputs][2] = -1
@@ -421,7 +421,7 @@ def prepare_inputs(
                 model_inputs = self.model.input(input_name)
                 dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
                 shape = model_inputs.get_partial_shape()
-                if self.config.model_type == "chatglm":
+                if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
                     shape[0] = 0
                     shape[1] = batch_size
                 else:
@@ -573,7 +573,7 @@ def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_ke
                     tuple(
                         (
                             past_state[indicies]
-                            if not self.config.model_type == "chatglm"
+                            if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
                             else past_state[:, indicies, ...]
                         )
                         for past_state in layer_past
@@ -607,7 +613,13 @@ def _deduplicate_inputs(self, model_inputs: Dict):
                     upd_batch_size = indicies.shape[0]
                     if self.config.model_type == "bloom":
                         upd_batch_size *= self.config.num_attention_heads
-                    shape[0 if not self.config.model_type == "chatglm" else 1] = upd_batch_size
+                    shape[
+                        (
+                            0
+                            if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
+                            else 1
+                        )
+                    ] = upd_batch_size
                     upd_model_inputs[input_name] = Tensor(dtype, shape)
         upd_model_inputs["input_ids"] = unique_input_ids
         if "beam_idx" in model_inputs:
@@ -675,7 +681,7 @@ def _get_past_length(self, past_key_values=None):
         ):
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
-        if self.config.model_type == "chatglm":
+        if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
             seq_length_dim = 0
         elif self.config.model_type == "qwen":
             seq_length_dim = 1
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 18b65e8002..8568f4d4df 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -38,7 +38,6 @@
 from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_accelerate_available
-from transformers.utils.quantization_config import QuantizationMethod
 
 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
 from optimum.exporters.tasks import TasksManager
@@ -828,7 +827,7 @@ def _weight_only_quantization(
         group_size=config.group_size,
         all_layers=config.all_layers,
         sensitivity_metric=sensitivity_metric,
-        awq=config.quant_method == QuantizationMethod.AWQ or None,
+        awq=getattr(config.quant_method, "name", "") == "AWQ" or None,
         ignored_scope=config.get_ignored_scope_instance(),
         dataset=dataset,
         subset_size=config.num_samples if config.num_samples else 128,
diff --git a/pyproject.toml b/pyproject.toml
index 62589e113c..643e25f21d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,3 +29,8 @@ line-length = 119
 [tool.ruff.isort]
 lines-after-imports = 2
 known-first-party = ["optimum"]
+
+[tool.pytest.ini_options]
+markers = [
+    "run_slow",
+]
\ No newline at end of file
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 7be12ce6e0..b52d2094cc 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -643,6 +643,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "xverse",
         "internlm",
         "jais",
+        "glm4",
     )
 
     if is_transformers_version(">=", "4.40.0"):
@@ -675,6 +676,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "internlm",
         "codegen2",
         "arctic",
+        "glm4",
     )
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -716,7 +718,7 @@ def test_compare_to_transformers(self, model_arch):
 
         set_seed(SEED)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-        if model_arch in ["qwen", "arctic"]:
+        if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)
 
         with torch.no_grad():
@@ -729,7 +731,7 @@ def test_compare_to_transformers(self, model_arch):
         if model_arch == "qwen":
             return
 
-        if model_arch not in ["chatglm", "persimmon"]:
+        if model_arch not in ["chatglm", "glm4", "persimmon"]:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
         if model_arch == "persimmon":
@@ -990,14 +992,21 @@ def test_beam_search(self, model_arch):
             ov_model_stateless.config.eos_token_id = None
             transformers_model.config.eos_token_id = None
 
-        for idx, gen_config in enumerate(gen_configs):
+        for gen_config in gen_configs:
             if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
                 continue
+
             transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
             ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
-            self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs), f"generation config : {idx}")
+            self.assertTrue(
+                torch.equal(ov_stateful_outputs, transformers_outputs),
+                f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model_stateful output {ov_stateful_outputs}",
+            )
             ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
-            self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs), f"generation config : {idx}")
+            self.assertTrue(
+                torch.equal(ov_stateless_outputs, transformers_outputs),
+                f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model_stateless output {ov_stateless_outputs}",
+            )
 
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 637876511d..df727eb10d 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -22,7 +22,7 @@
 from enum import Enum
 from functools import partial
 from typing import Union
-
+import pytest
 import evaluate
 import numpy as np
 import torch
@@ -37,6 +37,7 @@
     TrainingArguments,
     default_data_collator,
 )
+from transformers.testing_utils import slow
 from transformers.utils.quantization_config import QuantizationMethod
 
 from optimum.intel import (
@@ -173,7 +174,6 @@ def preprocess_function(examples, tokenizer):
 
 
 class OVWeightCompressionTest(unittest.TestCase):
-    # TODO : add models
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
         (OVModelForSequenceClassification, "bert", 70, 70),
         (OVModelForCausalLM, "gpt2", 44, 44),
@@ -181,7 +181,6 @@ class OVWeightCompressionTest(unittest.TestCase):
 
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 62, 86),)
     SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 148),)
-
     SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "gpt2", 44, 44),)
 
     LOAD_IN_4_BITS_SCOPE = (
@@ -236,6 +235,20 @@ class OVWeightCompressionTest(unittest.TestCase):
             ),
             16,
         ),
+        (
+            OVModelForCausalLM,
+            "llama_awq",
+            dict(
+                bits=4,
+                sym=True,
+                group_size=16,
+                ratio=0.8,
+                sensitivity_metric="mean_activation_magnitude",
+                dataset="c4",
+                quant_method="awq",
+            ),
+            16,
+        ),
     )
 
     SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
@@ -347,7 +360,6 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
         self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
-    @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
     def test_ovmodel_8bit_weight_compression_stateful(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
         task = model_cls.export_feature
         model_id = MODEL_NAMES[model_name]
@@ -415,9 +427,9 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
         ]
         model = model_cls.from_pretrained(model_id, export=True)
         quantizer = OVQuantizer(model)
-        quantization_config = OVWeightQuantizationConfig(
-            bits=8, num_samples=3, quant_method=OVQuantizationMethod.HYBRID
-        )
+        quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=3, quant_method="hybrid")
+        self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)
+
         quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
         num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet)
         self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
@@ -456,7 +468,7 @@ def test_ovmodel_4bit_auto_compression_with_config(
         with tempfile.TemporaryDirectory() as tmp_dir:
             quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
             model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
-            if quantization_config.quant_method == QuantizationMethod.AWQ or quantization_config.scale_estimation:
+            if quantization_config.quant_method.lower() == "awq" or quantization_config.scale_estimation:
                 # TODO: Check that AWQ and SE was actually applied
                 pass
 
@@ -473,7 +485,6 @@ def test_ovmodel_4bit_auto_compression_with_config(
         self.assertEqual(openvino_config.dtype, "int4")
 
     @parameterized.expand(((OVModelForCausalLM, "gpt2"),))
-    @unittest.skipIf(not IS_SUPPORT_STATEFUL, "Stateful models supported only in 2023.3 and above")
     def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_type):
         model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=True)
         self.assertTrue(model.stateful)
@@ -588,9 +599,11 @@ def test_ovmodel_4bit_dynamic_with_config(self, model_cls, model_name, quantizat
 
 
 class OVQuantizerQATest(unittest.TestCase):
-    SUPPORTED_ARCHITECTURES = (("hf-internal-testing/tiny-random-BertForQuestionAnswering",),)
+    SUPPORTED_ARCHITECTURES = ("hf-internal-testing/tiny-random-BertForQuestionAnswering",)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_automodel_static_quantization(self, model_name):
         def preprocess_function(examples, tokenizer):
             return tokenizer(
@@ -630,6 +643,8 @@ def preprocess_function(examples, tokenizer):
             self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_ovmodel_static_quantization(self, model_name):
         def preprocess_function(examples, tokenizer):
             return tokenizer(
@@ -670,12 +685,13 @@ def preprocess_function(examples, tokenizer):
 
 
 class OVTrainerTest(unittest.TestCase):
-    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("distilbert-base-uncased", 67, 38),)
+    SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("albert", 65, 39),)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
     def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
-        model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model_id = MODEL_NAMES[model_name]
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         ov_config = OVConfig()
         dataset = load_dataset("glue", "sst2")
         dataset = dataset.map(
diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py
index ab6f6f21a6..e735a07fb4 100644
--- a/tests/openvino/test_stable_diffusion.py
+++ b/tests/openvino/test_stable_diffusion.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 import PIL
+import pytest
 import torch
 from diffusers import (
     StableDiffusionPipeline,
@@ -29,6 +30,7 @@
 from diffusers.utils.testing_utils import floats_tensor
 from openvino.runtime.ie_api import CompiledModel
 from parameterized import parameterized
+from transformers.testing_utils import slow
 from utils_tests import MODEL_NAMES, SEED
 
 from optimum.intel import (
@@ -106,6 +108,8 @@ def test_num_images_per_prompt(self, model_arch: str):
                     self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_callback(self, model_arch: str):
         MODEL_NAMES[model_arch]
 
@@ -178,6 +182,8 @@ def test_compare_diffusers_pipeline(self, model_arch: str):
         self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
@@ -260,6 +266,8 @@ def test_image_reproducibility(self, model_arch: str):
         self.assertFalse(np.array_equal(ov_outputs_1.images[0], ov_outputs_3.images[0]))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False)
@@ -332,6 +340,8 @@ def test_compare_diffusers_pipeline(self, model_arch: str):
         self.assertTrue(np.allclose(outputs[0, -3:, -3:, -1].flatten(), expected_slice, atol=1e-1))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
@@ -420,6 +430,8 @@ def test_image_reproducibility(self, model_arch: str):
         self.assertFalse(np.array_equal(ov_outputs_1.images[0], ov_outputs_3.images[0]))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False)
@@ -458,6 +470,8 @@ def test_inference(self):
         self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-3))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
         pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
@@ -525,6 +539,8 @@ def test_compare_to_diffusers(self, model_arch: str):
         self.assertEqual(pipeline.device.type, ov_pipeline.device)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @pytest.mark.run_slow
+    @slow
     @unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
     def test_num_images_per_prompt_static_model(self, model_arch: str):
         model_id = MODEL_NAMES[model_arch]
diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py
index 25d430079c..375fc6e4a1 100644
--- a/tests/openvino/test_training.py
+++ b/tests/openvino/test_training.py
@@ -28,6 +28,7 @@
 import cpuinfo
 import evaluate
 import numpy as np
+import pytest
 import torch
 from datasets import load_dataset
 from nncf.experimental.torch.sparsity.movement.algo import MovementSparsityController
@@ -41,6 +42,7 @@
     AutoTokenizer,
     default_data_collator,
 )
+from transformers.testing_utils import slow
 from transformers.trainer_utils import EvalPrediction, TrainOutput
 from transformers.utils import WEIGHTS_NAME
 
@@ -613,6 +615,8 @@ class OVTrainerImageClassificationTrainingTest(OVTrainerBaseTrainingTest):
     task = "image-classification"
 
     @parameterized.expand(OVTRAINER_IMAGE_CLASSIFICATION_TEST_DESCRIPTORS.items())
+    @pytest.mark.run_slow
+    @slow
     @unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
     def test_training(self, _, desc: OVTrainerTestDescriptor):
         self.run_ovtrainer_training_checks(desc)
@@ -791,6 +795,8 @@ class OVTrainerAudioClassificationTrainingTest(OVTrainerBaseTrainingTest):
     task = "audio-classification"
 
    @parameterized.expand(OVTRAINER_AUDIO_CLASSIFICATION_TEST_DESCRIPTORS.items())
+    @pytest.mark.run_slow
+    @slow
     def test_training(self, _, desc: OVTrainerTestDescriptor):
         self.run_ovtrainer_training_checks(desc)
 
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 7f16a8e053..09919047cb 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -129,6 +129,7 @@
     "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
     "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM",
     "xverse": "katuni4ka/tiny-random-xverse",
+    "glm4": "katuni4ka/tiny-random-glm4",
 }
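The string-to-enum coercion added in `OVWeightQuantizationConfig.__init__` (together with the new `OVQuantizationMethod.AWQ` member) means AWQ weight compression can now be requested with a plain `"awq"` string, which is what both the CLI change and the new `llama_awq` test entry above rely on. A minimal sketch of the same flow from user code, assuming only the `optimum.intel` entry points exercised by the tests (`OVModelForCausalLM`, `OVWeightQuantizationConfig`) and reusing the tiny GLM-4 checkpoint registered in `utils_tests.py` purely for illustration:

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# Same settings as the "llama_awq" entry added to LOAD_IN_4_BITS_SCOPE:
# 4-bit group-wise compression, with quant_method passed as a plain string
# and coerced to OVQuantizationMethod.AWQ by OVWeightQuantizationConfig.__init__.
quantization_config = OVWeightQuantizationConfig(
    bits=4,
    sym=True,
    group_size=16,
    ratio=0.8,
    sensitivity_metric="mean_activation_magnitude",
    dataset="c4",
    quant_method="awq",
)

# "katuni4ka/tiny-random-glm4" is the test checkpoint added in utils_tests.py;
# any causal-LM checkpoint supported by the exporter could be used instead.
model = OVModelForCausalLM.from_pretrained(
    "katuni4ka/tiny-random-glm4",
    export=True,
    quantization_config=quantization_config,
)
```

Because `quant_method` is stored as an `OVQuantizationMethod` (a `str`-based enum), `_weight_only_quantization` can still decide whether to pass `awq=True` to NNCF without importing `QuantizationMethod` from transformers. The heavier paths touched here are gated by the `run_slow` marker registered in `pyproject.toml`, e.g. `RUN_SLOW=1 pytest tests/openvino -s -m "run_slow"`.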