diff --git a/tests/composition/test_parallel.py b/tests/composition/test_parallel.py index 31dce0996..c385ea24d 100644 --- a/tests/composition/test_parallel.py +++ b/tests/composition/test_parallel.py @@ -131,10 +131,10 @@ def test_parallel_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + input_ids = self.extract_input_ids( + self.get_input_samples(self.generate_input_samples_shape, config=model1.config) + ) + input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (2, seq_output_length)) diff --git a/tests/methods/test_compacter.py b/tests/methods/test_compacter.py index 292fab1ef..2c91b7536 100644 --- a/tests/methods/test_compacter.py +++ b/tests/methods/test_compacter.py @@ -71,10 +71,10 @@ def test_compacter_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + input_ids = self.extract_input_ids( + self.get_input_samples(self.generate_input_samples_shape, config=model1.config) + ) + input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index dd443c0d0..9c3b0822a 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -94,10 +94,10 @@ def test_prefix_tuning_generate(self): seq_output_length = 32 # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + input_ids = self.extract_input_ids( + self.get_input_samples(self.generate_input_samples_shape, config=model1.config) + ) + input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index bafa7e65a..1802aa5c0 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -35,8 +35,10 @@ def ids_tensor(shape, vocab_size, rng=None, name=None): class AdapterTestBase: # If not overriden by subclass, AutoModel should be used. model_class = AutoAdapterModel + tokenizer_name = "tests/fixtures/SiBERT" # Default shape of inputs to use default_input_samples_shape = (3, 64) + generate_input_samples_shape = (1, 4) leave_out_layers = [0, 1] do_run_train_tests = True # default arguments for test_adapter_heads @@ -98,6 +100,9 @@ def assert_adapter_unavailable(self, model, adapter_name): self.assertFalse(adapter_name in model.adapters_config) self.assertEqual(len(model.get_adapter(adapter_name)), 0) + def extract_input_ids(self, inputs): + return inputs["input_ids"] + class VisionAdapterTestBase(AdapterTestBase): default_input_samples_shape = (3, 3, 224, 224) @@ -146,10 +151,14 @@ class SpeechAdapterTestBase(AdapterTestBase): """Base class for speech adapter tests.""" default_input_samples_shape = (3, 80, 3000) # (batch_size, n_mels, enc_seq_len) + generate_input_samples_shape = (1, 80, 3000) is_speech_model = True # Flag for tests to determine if the model is a speech model due to input format difference time_window = 3000 # Time window for audio samples seq_length = 80 + def extract_input_ids(self, inputs): + return inputs["input_features"] + def add_head(self, model, name, head_type="seq2seq_lm", **kwargs): """Adds a head to the model.""" if head_type == "audio_classification": diff --git a/tests/test_adapter_embeddings.py b/tests/test_adapter_embeddings.py index 160828c77..64a07d381 100644 --- a/tests/test_adapter_embeddings.py +++ b/tests/test_adapter_embeddings.py @@ -182,6 +182,6 @@ def _instantiate_tokenizer(self, model): tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) input_data = self.get_input_samples(config=self.config()) else: - tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) input_data = self.get_input_samples((1, 128), vocab_size=tokenizer.vocab_size, config=model.config) return tokenizer, input_data diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index 541debf35..c0c3812cc 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -175,10 +175,8 @@ def test_seq2seq_lm_head(self): # Finally, also check if generation works properly input_shape = self._get_input_shape() - if self.is_speech_model: - input_ids = self.get_input_samples(input_shape, config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples(input_shape, config=model1.config)["input_ids"] + input_ids = self.extract_input_ids(self.get_input_samples(input_shape, config=model1.config)) + input_ids = input_ids.to(torch_device) # Use a different length for the seq2seq output seq_output_length = self.seq_length + 30