Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Aug 2, 2024
1 parent 86511aa commit b294474
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 29 deletions.
4 changes: 1 addition & 3 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@
# and it implements many new features including short and long form generation, and starts with 2 init tokens
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
else:

class WhisperGenerationMixin:
generate = WhisperForConditionalGeneration.generate
WhisperGenerationMixin = WhisperForConditionalGeneration


if is_transformers_version(">=", "4.43.0"):
Expand Down
30 changes: 4 additions & 26 deletions tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,24 +1666,13 @@ def _generate_random_audio_data(self):
def test_compare_to_transformers(self, model_arch):
set_seed(SEED)
model_id = MODEL_NAMES[model_arch]

if is_transformers_version(">=", "4.37"):
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
else:
with self.assertRaises(Exception) as context:
_ = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertIn(
"Whisper is not available for this version of Transformers, please upgrade to 4.37.0 or later.",
str(context.exception),
)
return

self.assertIsInstance(ov_model.config, PretrainedConfig)
transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_model.config, PretrainedConfig)

processor = get_preprocessor(model_id)
data = self._generate_random_audio_data()
features = processor.feature_extractor(data, return_tensors="pt")

decoder_start_token_id = transformers_model.config.decoder_start_token_id
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

Expand Down Expand Up @@ -1711,19 +1700,8 @@ def test_compare_to_transformers(self, model_arch):
def test_pipeline(self, model_arch):
set_seed(SEED)
model_id = MODEL_NAMES[model_arch]
model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

if is_transformers_version(">=", "4.37"):
model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
else:
with self.assertRaises(Exception) as context:
_ = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
self.assertIn(
"Whisper is not available for this version of Transformers, please upgrade to 4.37.0 or later.",
str(context.exception),
)
return

model.eval()
processor = get_preprocessor(model_id)
pipe = pipeline(
"automatic-speech-recognition",
Expand Down

0 comments on commit b294474

Please sign in to comment.