19
19
import numpy as np
20
20
import onnxruntime as ort
21
21
from numpy .linalg import norm
22
+ from silero_vad import vad
22
23
23
24
from wespeaker .cli .hub import Hub
24
25
from wespeaker .cli .fbank import logfbank
@@ -28,12 +29,25 @@ class Speaker:
28
29
def __init__(self, model_path: str, resample_rate: int = 16000):
    """Create the ONNX inference session and the VAD model.

    Args:
        model_path: Path handed to ``ort.InferenceSession`` to load the
            speaker embedding model.
        resample_rate: Sample rate that audio is resampled to when it is
            loaded in ``extract_embedding``.
    """
    self.session = ort.InferenceSession(model_path)
    self.resample_rate = resample_rate
    # VAD model handed to ``vad.get_speech_timestamps`` when embeddings are
    # extracted with ``apply_vad=True``.
    self.vad_model = vad.OnnxWrapper()
    # Starts empty; presumably filled elsewhere in the class -- confirm
    # against the methods outside this view.
    self.table = {}
32
34
33
- def extract_embedding (self , audio_path : str ):
35
+ def extract_embedding (self , audio_path : str , apply_vad : bool = False ):
34
36
pcm , sample_rate = librosa .load (audio_path , sr = self .resample_rate )
35
37
pcm = pcm * (1 << 15 )
36
- # NOTE: produce the same results as with torchaudio.compliance.kaldi
38
+ if apply_vad :
39
+ # TODO(Binbin Zhang): Refine the segments logic, here we just
40
+ # suppose there is only silence at the start/end of the speech
41
+ segments = vad .get_speech_timestamps (self .vad_model ,
42
+ audio_path ,
43
+ return_seconds = True )
44
+ if len (segments ) > 0 : # remove head and tail silence
45
+ start = int (segments [0 ]['start' ] * sample_rate )
46
+ end = int (segments [- 1 ]['end' ] * sample_rate )
47
+ pcm = pcm [start :end ]
48
+ else : # all silence, nospeech
49
+ return None
50
+
37
51
feats = logfbank (
38
52
pcm ,
39
53
sample_rate ,
@@ -50,9 +64,12 @@ def extract_embedding(self, audio_path: str):
50
64
return embedding
51
65
52
66
def compute_similarity(self, audio_path1: str, audio_path2: str) -> float:
    """Score how similar the speakers of two audio files are.

    Both embeddings are extracted with VAD enabled.

    Args:
        audio_path1: Path of the first audio file.
        audio_path2: Path of the second audio file.

    Returns:
        The value of ``cosine_distance`` for the two embeddings, or ``0.0``
        when either file yields no embedding (``extract_embedding``
        returned ``None`` because no speech was detected).
    """
    e1 = self.extract_embedding(audio_path1, apply_vad=True)
    e2 = self.extract_embedding(audio_path2, apply_vad=True)
    # Guard clause: a file without speech cannot be compared.
    if e1 is None or e2 is None:
        return 0.0
    return self.cosine_distance(e1, e2)
56
73
57
74
def cosine_distance(self, e1, e2):
    """Return the cosine similarity of two embeddings.

    Despite the method name, the value is the normalised dot product
    (1.0 for identical directions), not ``1 - similarity``.

    Args:
        e1: First embedding vector.
        e2: Second embedding vector.

    Returns:
        ``dot(e1, e2) / (|e1| * |e2|)``, or ``0.0`` when either vector has
        zero norm, so the result is never ``nan``.
    """
    denom = norm(e1) * norm(e2)
    # A zero-norm embedding has no direction; dividing would give nan and
    # a NumPy RuntimeWarning, so report "no similarity" instead.
    if denom == 0:
        return 0.0
    return np.dot(e1, e2) / denom
@@ -109,6 +126,9 @@ def get_args():
109
126
type = int ,
110
127
default = 16000 ,
111
128
help = 'resampling rate' )
129
+ parser .add_argument ('--vad' ,
130
+ action = 'store_true' ,
131
+ help = 'whether to do VAD or not' )
112
132
args = parser .parse_args ()
113
133
return args
114
134
@@ -117,7 +137,7 @@ def main():
117
137
args = get_args ()
118
138
model = load_model (args .language , args .resample_rate )
119
139
if args .task == 'embedding' :
120
- print (model .extract_embedding (args .audio_file ))
140
+ print (model .extract_embedding (args .audio_file , args . vad ))
121
141
elif args .task == 'similarity' :
122
142
print (model .compute_similarity (args .audio_file , args .audio_file2 ))
123
143
elif args .task == 'diarization' :
0 commit comments