
Commit 6bbaf79

[cli] add vad (#217)
* [cli] add vad
* fix parameter
* fix logic error
1 parent 803e08c commit 6bbaf79

2 files changed (+27, −6 lines)

setup.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
     "tqdm",
     "onnxruntime>=1.12.0",
     "librosa>=0.8.0",
+    "silero-vad @ git+https://github.com/pengzhendong/silero-vad.git",
 ]

 setup(
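
The pinned fork supplies the vad module consumed in wespeaker/cli/speaker.py below. As a reference, here is a minimal sketch of the interface the CLI relies on; the audio path is a placeholder and the fork's API is assumed to match the usage shown in the diff below.

# Sketch only; assumes the pinned silero-vad fork exposes the interface used in speaker.py.
from silero_vad import vad

vad_model = vad.OnnxWrapper()  # load the silero VAD ONNX model
segments = vad.get_speech_timestamps(vad_model,
                                     "utt.wav",  # placeholder audio path
                                     return_seconds=True)
# segments is a list of dicts with 'start'/'end' times in seconds,
# e.g. [{'start': 0.5, 'end': 2.4}, ...]; it is empty when no speech is detected.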

wespeaker/cli/speaker.py

Lines changed: 26 additions & 6 deletions
@@ -19,6 +19,7 @@
 import numpy as np
 import onnxruntime as ort
 from numpy.linalg import norm
+from silero_vad import vad

 from wespeaker.cli.hub import Hub
 from wespeaker.cli.fbank import logfbank
@@ -28,12 +29,25 @@ class Speaker:
     def __init__(self, model_path: str, resample_rate: int = 16000):
         self.session = ort.InferenceSession(model_path)
         self.resample_rate = resample_rate
+        self.vad_model = vad.OnnxWrapper()
         self.table = {}

-    def extract_embedding(self, audio_path: str):
+    def extract_embedding(self, audio_path: str, apply_vad: bool = False):
         pcm, sample_rate = librosa.load(audio_path, sr=self.resample_rate)
         pcm = pcm * (1 << 15)
-        # NOTE: produce the same results as with torchaudio.compliance.kaldi
+        if apply_vad:
+            # TODO(Binbin Zhang): Refine the segments logic, here we just
+            # suppose there is only silence at the start/end of the speech
+            segments = vad.get_speech_timestamps(self.vad_model,
+                                                 audio_path,
+                                                 return_seconds=True)
+            if len(segments) > 0:  # remove head and tail silence
+                start = int(segments[0]['start'] * sample_rate)
+                end = int(segments[-1]['end'] * sample_rate)
+                pcm = pcm[start:end]
+            else:  # all silence, no speech
+                return None
+
         feats = logfbank(
             pcm,
             sample_rate,
@@ -50,9 +64,12 @@ def extract_embedding(self, audio_path: str):
         return embedding

     def compute_similarity(self, audio_path1: str, audio_path2) -> float:
-        e1 = self.extract_embedding(audio_path1)
-        e2 = self.extract_embedding(audio_path2)
-        return self.cosine_distance(e1, e2)
+        e1 = self.extract_embedding(audio_path1, True)
+        e2 = self.extract_embedding(audio_path2, True)
+        if e1 is None or e2 is None:
+            return 0.0
+        else:
+            return self.cosine_distance(e1, e2)

     def cosine_distance(self, e1, e2):
         return np.dot(e1, e2) / (norm(e1) * norm(e2))
@@ -109,6 +126,9 @@ def get_args():
                         type=int,
                         default=16000,
                         help='resampling rate')
+    parser.add_argument('--vad',
+                        action='store_true',
+                        help='whether to do VAD or not')
     args = parser.parse_args()
     return args

@@ -117,7 +137,7 @@ def main():
     args = get_args()
     model = load_model(args.language, args.resample_rate)
     if args.task == 'embedding':
-        print(model.extract_embedding(args.audio_file))
+        print(model.extract_embedding(args.audio_file, args.vad))
     elif args.task == 'similarity':
         print(model.compute_similarity(args.audio_file, args.audio_file2))
     elif args.task == 'diarization':
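
For reference, a minimal usage sketch of the updated Python API; the model and audio paths are placeholders, and the behavior described follows the diff above.

# Sketch only; "model.onnx", "utt1.wav" and "utt2.wav" are placeholder paths.
from wespeaker.cli.speaker import Speaker

speaker = Speaker("model.onnx")  # ONNX speaker model; resample_rate defaults to 16000

# With apply_vad=True, head/tail silence is trimmed before feature extraction;
# the call returns None if the VAD detects no speech at all.
emb = speaker.extract_embedding("utt1.wav", apply_vad=True)

# compute_similarity now always applies VAD and returns 0.0
# when either recording contains no detected speech.
score = speaker.compute_similarity("utt1.wav", "utt2.wav")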
