# test.py: Whisper TFLite test inference.
import os
import re
import sys
import wave
import argparse
from timeit import default_timer as timer
#import json
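# Command-line options: input folder, model path, thread count, language and runtime.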
parser = argparse.ArgumentParser(description="Running Whisper TFlite test inference.")
parser.add_argument("-f", "--folder", default="../test-files/", help="Folder with WAV input files")
parser.add_argument("-m", "--model", default="models/whisper-tiny.tflite", help="Path to model")
parser.add_argument("-t", "--threads", default=2, help="Threads used (default: 2)")
parser.add_argument("-l", "--lang", default="en", help="Language used (default: en)")
parser.add_argument("-r", "--runtime", default="1", help="Tensorflow runtime, use '1' (default) for tf.lite or '2' for tflite_runtime")
args = parser.parse_args()
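# Runtime '1' uses the full TensorFlow package (tf.lite); any other value uses
# the lightweight tflite_runtime package instead.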
if args.runtime == "1":
    print('Importing tensorflow (for tf.lite)')
    import tensorflow as tf
else:
    print('Importing tflite_runtime')
    import tflite_runtime.interpreter as tf
print('Importing numpy')
import numpy as np
#import torch
print('Importing whisper')
import whisper
model_path = args.model
print(f'\nLoading tflite model {model_path} ...')
if args.runtime == "1":
    interpreter = tf.lite.Interpreter(model_path, num_threads=int(args.threads))
else:
    interpreter = tf.Interpreter(model_path, num_threads=int(args.threads))
print(f'Threads: {args.threads}')
interpreter.allocate_tensors()
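# Indices of the model's input and output tensors.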
input_tensor = interpreter.get_input_details()[0]['index']
output_tensor = interpreter.get_output_details()[0]['index']
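# '.en' models are English-only; all other Whisper models are multilingual.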
if ".en" in model_path:
wtokenizer = whisper.tokenizer.get_tokenizer(False, language="en")
else:
wtokenizer = whisper.tokenizer.get_tokenizer(True, language=args.lang)
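# Transcribe a single 16 kHz mono PCM WAV file with the loaded TFLite model.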
def transcribe(audio_file):
    print(f'\nLoading audio file: {audio_file}')
    wf = wave.open(audio_file, "rb")
    sample_rate_orig = wf.getframerate()
    audio_length = wf.getnframes() / sample_rate_orig
    if (wf.getnchannels() != 1 or wf.getsampwidth() != 2
            or wf.getcomptype() != "NONE" or sample_rate_orig != 16000):
        print("Audio file must be WAV format mono PCM.")
        sys.exit(1)
    wf.close()
    print(f'Samplerate: {sample_rate_orig}, length: {audio_length}s')
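    # Look for a two-letter language code prefix in the file name (e.g. "en_test.wav").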
    file_lang = None
    lang_search = re.findall(r"(?:^|/)(\w\w)_", audio_file)
    if len(lang_search) > 0:
        file_lang = lang_search.pop()
    if ".en" in model_path and file_lang != "en":
        print(f"Language found in file name: {file_lang}")
        print("Skipped file to avoid issues with '.en' model.")
        return
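    # Time the whole pipeline: feature extraction, inference and token decoding.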
    inference_start = timer()
    print('\nCalculating mel spectrogram...')
    mel_from_file = whisper.audio.log_mel_spectrogram(audio_file)
    input_data = whisper.audio.pad_or_trim(mel_from_file, whisper.audio.N_FRAMES)
    input_data = np.expand_dims(input_data, 0)
    #print("Input data shape:", input_data.shape)
    print("Invoking interpreter ...")
    interpreter.set_tensor(input_tensor, input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_tensor)
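    # output_data holds one row of predicted token ids per batch entry.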
print("Preparing output data ...")
output_details = interpreter.get_output_details()
output_data = interpreter.get_tensor(output_details[0]['index'])
#output_data = output_data.squeeze()
#print(output_data)
#np.savetxt("output.txt", output_data)
#print(interpreter.get_output_details()[0])
    # Convert tokens to text: replace the -100 placeholder values with the
    # end-of-text token, then decode each sequence with the tokenizer.
    print("Converting tokens ...")
    for token in output_data:
        #print(token)
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token, skip_special_tokens=True)
        print(text)
    print("\nInference took {:.2f}s for {:.2f}s audio file.".format(
        timer() - inference_start, audio_length))
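# Run the test on every WAV file in the given folder.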
test_files = os.listdir(args.folder)
for file in test_files:
    if file.endswith(".wav"):
        transcribe(os.path.join(args.folder, file))