-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
27 lines (23 loc) · 978 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from IPython.display import Audio
import tensorflow as tf
def get_audio(path):
    """Read a WAV file from ``path`` and decode it down-mixed to one channel.

    The sample rate reported by the decoder is dropped; only the sample
    tensor is returned.

    Args:
        path: Location of the WAV file, passed to ``tf.io.read_file``.

    Returns:
        The decoded audio tensor produced by ``tf.audio.decode_wav`` with
        ``desired_channels=1``.
    """
    raw_bytes = tf.io.read_file(path)
    samples, _sample_rate = tf.audio.decode_wav(raw_bytes, desired_channels=1)
    return samples
def inference_preprocess(path, batching_size=12000):
    """Load a WAV file and cut it into fixed-length windows for inference.

    The clip is split into consecutive, non-overlapping windows of
    ``batching_size`` samples that end strictly before the clip's end. A
    final window covering the last ``batching_size`` samples is then
    appended. That final window may overlap the previous one; ``diff``
    records how many of its trailing samples are not already covered by an
    earlier window. The caller can then stitch per-window outputs back
    together without duplicating samples.

    Clips shorter than ``batching_size`` are left-padded with zeros so the
    single window still has full length. The real samples then occupy the
    last ``diff`` positions, consistent with longer clips.

    Args:
        path: Path of the WAV file, passed to ``get_audio``.
        batching_size: Window length in samples. Defaults to 12000.

    Returns:
        A tuple ``(batches, diff)``:
        - ``batches`` is ``tf.stack`` of the windows, so the first axis
          indexes windows.
        - ``diff`` is the number of trailing samples of the last window that
          the preceding windows do not cover (1 <= diff <= batching_size).

    Raises:
        ValueError: If ``batching_size`` is not positive or no samples were
            decoded from the file.
    """
    if batching_size <= 0:
        raise ValueError(f"batching_size must be positive, got {batching_size}")
    audio = get_audio(path)
    audio_len = audio.shape[0]
    if audio_len == 0:
        # diff would be 0, and a later x[-0:] slice selects the whole tensor
        # instead of nothing, so refuse empty input explicitly.
        raise ValueError(f"no audio samples decoded from {path!r}")
    if audio_len < batching_size:
        # decode_wav yields a 2-D [samples, channels] tensor. Pad only the
        # front of the sample axis so the real samples end the window.
        audio = tf.pad(audio, [[batching_size - audio_len, 0], [0, 0]])
    # Full, non-overlapping windows that end before the clip does. This is
    # empty when the clip is at most one window long.
    batches = [
        audio[start:start + batching_size]
        for start in range(0, audio_len - batching_size, batching_size)
    ]
    # Samples not covered by the windows above. Derived from the window
    # count so it is well-defined even when no full window was produced.
    diff = audio_len - len(batches) * batching_size
    # The last window is aligned to the end of the clip (it may overlap).
    batches.append(audio[-batching_size:])
    return tf.stack(batches), diff
def predict(path, model_path, *, batching_size=12000, rate=16000):
    """Run a saved Keras model over a WAV file and return playable audio.

    The clip is windowed by ``inference_preprocess`` and all windows go
    through the model in a single ``model.predict`` call. The per-window
    outputs are then stitched back into one sequence:
    - every window except the last is flattened in order;
    - only the trailing ``diff`` samples of the last window are appended,
      because that window is end-aligned and may overlap its predecessor.

    Args:
        path: Path of the input WAV file.
        model_path: Path passed to ``load_model``.
        batching_size: Window length in samples, forwarded to
            ``inference_preprocess``. It must match the window length the
            model expects. Defaults to 12000.
        rate: Sample rate in Hz handed to ``IPython.display.Audio``.
            Defaults to 16000.
            NOTE(review): ``get_audio`` discards the file's own sample rate,
            so this value is not checked against the input file.

    Returns:
        An ``IPython.display.Audio`` object wrapping the stitched model output.
    """
    model = load_model(model_path)
    test_data, diff = inference_preprocess(path, batching_size=batching_size)
    predictions = model.predict(test_data)
    # Flatten all windows but the last into a single column. This reshape
    # requires each window's output to hold exactly predictions.shape[1]
    # values, e.g. shape (windows, samples, 1).
    final_op = tf.reshape(
        predictions[:-1],
        ((predictions.shape[0] - 1) * predictions.shape[1], 1),
    )
    # Keep only the samples of the final, possibly overlapping window that
    # the preceding windows did not already produce.
    final_op = tf.concat((final_op, predictions[-1][-diff:]), axis=0)
    # Wrap the stitched tensor itself for playback.
    return Audio(tf.squeeze(final_op), rate=rate)