From d38c616f7f953d80fd1e2ed2bd9537268cb0fe92 Mon Sep 17 00:00:00 2001 From: eaidova Date: Fri, 3 Nov 2023 18:34:12 +0400 Subject: [PATCH] restrict transformers version for now... --- optimum/intel/openvino/trainer.py | 25 +++++-------------------- setup.py | 2 +- tests/openvino/test_modeling.py | 6 +++--- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 0e5521b4db..0bba054ad3 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -65,25 +65,11 @@ from transformers.trainer_utils import ( EvalPrediction, HPSearchBackend, + ShardedDDPOption, TrainOutput, has_length, speed_metrics, ) - - -try: - from transformers.trainer_utils import ShardedDDPOption -except ImportError: - from transformers.utils import ExplicitEnum - - class ShardedDDPOption(ExplicitEnum): - SIMPLE = "simple" - ZERO_DP_2 = "zero_dp_2" - ZERO_DP_3 = "zero_dp_3" - OFFLOAD = "offload" - AUTO_WRAP = "auto_wrap" - - from transformers.utils import ( WEIGHTS_NAME, is_apex_available, @@ -310,10 +296,9 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - sharded_ddp = getattr(self, "sharded_ddp", None) delay_optimizer_creation = ( - sharded_ddp is not None - and sharded_ddp != ShardedDDPOption.SIMPLE + self.sharded_ddp is not None + and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled() or self.fsdp is not None ) @@ -526,7 +511,7 @@ def _inner_training_loop( if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: # deepspeed does its own clipping - if getattr(self, "do_grad_scaling", False): + if self.do_grad_scaling: # AMP: gradients need unscaling self.scaler.unscale_(self.optimizer) @@ -549,7 +534,7 @@ def _inner_training_loop( optimizer_was_run = True if self.deepspeed: pass # called outside the loop - elif getattr(self, "do_grad_scaling", False): + elif self.do_grad_scaling: scale_before = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() diff --git a/setup.py b/setup.py index 6d81b98b2a..52494a67d4 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ "onnx", "onnxruntime<1.15.0", ], - "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"], + "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers<4.35.0"], "nncf": ["nncf>=2.6.0"], "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], "diffusers": ["diffusers"], diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index af0aa2a6d7..0c518e88f1 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -493,7 +493,7 @@ def test_compare_to_transformers(self, model_arch): set_seed(SEED) ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True) self.assertIsInstance(ov_model.config, PretrainedConfig) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None @@ -504,7 +504,8 @@ def test_compare_to_transformers(self, model_arch): with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Compare tensor outputs - self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-2)) + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4), + f"Max diff {torch.abs(ov_outputs.logits - transformers_outputs.logits).max()}") del transformers_model del ov_model gc.collect() @@ -1238,7 +1239,6 @@ def test_compare_to_transformers(self, model_arch): ov_outputs = ov_model(**features, **decoder_inputs) self.assertIn("logits", ov_outputs) - self.assertIsInstance(ov_outputs.logits, TENSOR_ALIAS_TO_TYPE[input_type]) # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(ov_outputs.logits), transformers_outputs.logits, atol=1e-3))