Skip to content

Commit

Permalink
fix: silero v5 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhecuican committed Aug 9, 2024
1 parent 0c345db commit 96e225f
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 231 deletions.
4 changes: 3 additions & 1 deletion Data/default_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@
},
"silero":{
"model": "silero_vad.onnx",
"chunkSize": 640
"chunkSize": 1024,
"thresh": 0.1,
"abandonNum": 5
},
"netease_cloud":{
"phone": "xxx",
Expand Down
1 change: 0 additions & 1 deletion Plugins/hass/hass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ QJsonObject Hass::parseParams(const Intent &intent,
const HassService &service) {
QJsonObject params = parseObject(intent, service.params);
intent.toString(0);
qDebug() << params;
return params;
}

Expand Down
19 changes: 19 additions & 0 deletions Test/tst_sherpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "../Utils/config.h"
#include "../Utils/wavfilereader.h"
#include "TestPluginHelper.h"
#include "../Wakeup/Vad/silerovad.h"
#include <QLibrary>
#include <QtTest/QtTest>
using namespace AC;
Expand All @@ -21,6 +22,23 @@ private slots:
Config::instance()->loadConfig();
mplayer = new Player(this);
}
void vad() {
SileroVad* vad = new SileroVad(this);
QFile file(Config::getDataPath("short_test.wav"));
file.open(QIODevice::ReadOnly);
QByteArray data = file.readAll();
file.close();
bool detected = false;
for(int i=0; i<data.size(); i+=vad->getChunkSize()){
if(data.size() - i < vad->getChunkSize()) break;
QByteArray testData = data.mid(i, vad->getChunkSize());
bool detect = vad->detectVoice(testData);
if(detect){
detected = true;
}
}
QCOMPARE(detected, true);
}
void asrSherpa() {
WavFileReader reader;
reader.OpenWavFile(Config::getDataPath("test2.wav").toStdString());
Expand All @@ -45,6 +63,7 @@ private slots:
QFile file(Config::getDataPath("short_test.wav"));
file.open(QIODevice::ReadOnly);
QByteArray data = file.readAll();
file.close();
mplayer->playRaw(data, 8000);
QBuffer buffer(&data);
QMediaPlayer *player = new QMediaPlayer;
Expand Down
211 changes: 6 additions & 205 deletions Wakeup/Vad/silerovad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ SileroVad::SileroVad(QObject *parent) : VadModel(parent) {
QString modelPath =
Config::getDataPath(sileroConfig.value("model").toString());
chunkSize = sileroConfig.value("chunkSize").toInt();
abandonNum = sileroConfig.value("abandonNum").toInt(5);
float thresh = sileroConfig.value("thresh").toDouble(0.1);
samples.resize(chunkSize / 2);
vad = new VadIterator(modelPath.toStdString(), 16000, chunkSize / 32, 0.1);
vad = new VadIterator(modelPath.toStdString(), 16000, chunkSize / 32, thresh);
QJsonObject wakeupConfig = Config::instance()->getConfig("wakeup");
detectSlient = wakeupConfig.find("detectSlient")->toInt();
undetectTimer = new QTimer(this);
Expand All @@ -29,7 +31,8 @@ bool SileroVad::detectVoice(const QByteArray &data) {
samples[i] = intData[i] / 32768.;
}
bool isVoice = vad->vadDetect(samples);
return isVoice;
abandonCurrent++;
return isVoice && abandonCurrent > abandonNum;
}

void SileroVad::stop() {
Expand All @@ -43,6 +46,7 @@ void SileroVad::stop() {
void SileroVad::startDetect(bool isResponse) {
VadModel::startDetect(isResponse);
vad->reset_states();
abandonCurrent = 0;
}

int SileroVad::getChunkSize() { return chunkSize; }
Expand Down Expand Up @@ -135,206 +139,3 @@ bool VadIterator::vadDetect(const std::vector<float> &data) {
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
return speech_prob > threshold;
}

// One streaming inference step of the Silero VAD state machine.
// Runs the ONNX model on a single window of normalized float samples,
// carries the recurrent state across calls via `_state`, and updates the
// segment-tracking members (triggered / temp_end / prev_end / next_start /
// current_speech) that process() later turns into speech timestamps.
// NOTE(review): assumes `data` holds exactly window_size_samples samples —
// TODO confirm callers always pass full windows.
void VadIterator::predict(const std::vector<float> &data) {
    // Infer
    // Create ort tensors (all referencing this object's persistent buffers)
    input.assign(data.begin(), data.end());
    Ort::Value input_ort = Ort::Value::CreateTensor<float>(
        memory_info, input.data(), input.size(), input_node_dims, 2);
    Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
        memory_info, sr.data(), sr.size(), sr_node_dims, 1);
    Ort::Value state_ort = Ort::Value::CreateTensor<float>(
        memory_info, _state.data(), _state.size(), state_node_dims, 3);

    // Clear and add inputs (order must match input_node_names: input, state, sr)
    ort_inputs.clear();
    ort_inputs.emplace_back(std::move(input_ort));
    ort_inputs.emplace_back(std::move(state_ort));
    ort_inputs.emplace_back(std::move(sr_ort));

    // Infer
    ort_outputs = session->Run(
        Ort::RunOptions{nullptr}, input_node_names.data(), ort_inputs.data(),
        ort_inputs.size(), output_node_names.data(), output_node_names.size());

    // Output probability & update h,c recursively
    // (output 0 = speech probability, output 1 = next recurrent state)
    float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
    float *stateN = ort_outputs[1].GetTensorMutableData<float>();
    std::memcpy(_state.data(), stateN, size_state * sizeof(float));

    // Push forward sample index
    current_sample += window_size_samples;

    // 1) Speech: probability at/above threshold — enter or stay in speech.
    // Reset temp_end when > threshold
    if ((speech_prob >= threshold)) {
#ifdef __DEBUG_SPEECH_PROB___
        float speech = current_sample -
                       window_size_samples; // minus window_size_samples to get
                                            // precise start time point.
        printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
               speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
        if (temp_end != 0) {
            // A tentative end was pending but speech resumed: cancel it.
            temp_end = 0;
            if (next_start < prev_end)
                next_start = current_sample - window_size_samples;
        }
        if (triggered == false) {
            triggered = true;

            current_speech.start = current_sample - window_size_samples;
        }
        return;
    }

    // 2) Forced split: the current segment exceeded max_speech_samples,
    // so close it here even though speech may still be ongoing.
    if ((triggered == true) &&
        ((current_sample - current_speech.start) > max_speech_samples)) {
        if (prev_end > 0) {
            current_speech.end = prev_end;
            speeches.push_back(current_speech);
            current_speech = timestamp_t();

            // previously reached silence(< neg_thres) and is still not speech(<
            // thres)
            if (next_start < prev_end)
                triggered = false;
            else {
                // Speech resumed after the tentative end: start the next
                // segment where it resumed.
                current_speech.start = next_start;
            }
            prev_end = 0;
            next_start = 0;
            temp_end = 0;

        } else {
            // No earlier silence point recorded: cut at the current sample.
            current_speech.end = current_sample;
            speeches.push_back(current_speech);
            current_speech = timestamp_t();
            prev_end = 0;
            next_start = 0;
            temp_end = 0;
            triggered = false;
        }
        return;
    }

    // 3) Hysteresis band [threshold - 0.15, threshold): ambiguous
    // probability — keep the current state unchanged.
    if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) {
        if (triggered) {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample -
                           window_size_samples; // minus window_size_samples to
                                                // get precise start time point.
            printf("{ speeking: %.3f s (%.3f) %08d}\n",
                   1.0 * speech / sample_rate, speech_prob,
                   current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
        } else {
#ifdef __DEBUG_SPEECH_PROB___
            float speech = current_sample -
                           window_size_samples; // minus window_size_samples to
                                                // get precise start time point.
            printf("{ silence: %.3f s (%.3f) %08d}\n",
                   1.0 * speech / sample_rate, speech_prob,
                   current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
        }
        return;
    }

    // 4) End
    // Probability below (threshold - 0.15): record a tentative end and,
    // once the silence has lasted min_silence_samples, finalize the segment.
    if ((speech_prob < (threshold - 0.15))) {
#ifdef __DEBUG_SPEECH_PROB___
        float speech = current_sample - window_size_samples -
                       speech_pad_samples; // minus window_size_samples to get
                                           // precise start time point.
        printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
               speech_prob, current_sample - window_size_samples);
#endif //__DEBUG_SPEECH_PROB___
        if (triggered == true) {
            if (temp_end == 0) {
                temp_end = current_sample;
            }
            if (current_sample - temp_end > min_silence_samples_at_max_speech)
                prev_end = temp_end;
            // a. silence < min_slience_samples, continue speaking
            if ((current_sample - temp_end) < min_silence_samples) {

            }
            // b. silence >= min_slience_samples, end speaking
            else {
                current_speech.end = temp_end;
                if (current_speech.end - current_speech.start >
                    min_speech_samples) {
                    speeches.push_back(current_speech);
                    current_speech = timestamp_t();
                    prev_end = 0;
                    next_start = 0;
                    temp_end = 0;
                    triggered = false;
                }
            }
        } else {
            // may first windows see end state.
        }
        return;
    }
}

// Run detection over a whole utterance: reset the model state, feed one
// full window at a time to predict(), and close any speech segment that
// is still open when the audio ends.
void VadIterator::process(const std::vector<float> &input_wav) {
    reset_states();
    audio_length_samples = input_wav.size();

    int offset = 0;
    // Trailing samples that do not fill a complete window are dropped.
    while (offset + window_size_samples <= audio_length_samples) {
        std::vector<float> window(input_wav.begin() + offset,
                                  input_wav.begin() + offset +
                                      window_size_samples);
        predict(window);
        offset += window_size_samples;
    }

    // Flush a segment that was still in progress at end-of-audio.
    if (current_speech.start >= 0) {
        current_speech.end = audio_length_samples;
        speeches.push_back(current_speech);
        current_speech = timestamp_t();
        prev_end = 0;
        next_start = 0;
        temp_end = 0;
        triggered = false;
    }
}

// Convenience overload: detect speech in `input_wav`, then write the
// concatenation of all detected speech chunks into `output_wav`.
void VadIterator::process(const std::vector<float> &input_wav,
                          std::vector<float> &output_wav) {
    process(input_wav);
    collect_chunks(input_wav, output_wav);
}

// Concatenate all detected speech segments from `input_wav` into
// `output_wav` (which is cleared first). Segment bounds come from the
// `speeches` timestamps produced by process()/predict().
void VadIterator::collect_chunks(const std::vector<float> &input_wav,
                                 std::vector<float> &output_wav) {
    output_wav.clear();
    for (int i = 0; i < (int)speeches.size(); i++) {
#ifdef __DEBUG_SPEECH_PROB___
        std::cout << speeches[i].c_str() << std::endl;
#endif // #ifdef __DEBUG_SPEECH_PROB___
        // Insert directly via iterators: avoids the old temporary `slice`
        // copy and the UB of `&input_wav[end]` when end == size()
        // (operator[] at one-past-the-end is undefined behavior).
        output_wav.insert(output_wav.end(),
                          input_wav.begin() + speeches[i].start,
                          input_wav.begin() + speeches[i].end);
    }
}

// Inverse of collect_chunks(): copy everything EXCEPT the detected speech
// segments from `input_wav` into `output_wav` (which is cleared first),
// i.e. the gaps before, between, and after the `speeches` timestamps.
void VadIterator::drop_chunks(const std::vector<float> &input_wav,
                              std::vector<float> &output_wav) {
    output_wav.clear();
    int current_start = 0;
    for (int i = 0; i < (int)speeches.size(); i++) {
        // Gap between the previous segment's end and this segment's start.
        // Iterator insert avoids the old temporary `slice` copies.
        output_wav.insert(output_wav.end(),
                          input_wav.begin() + current_start,
                          input_wav.begin() + speeches[i].start);
        current_start = speeches[i].end;
    }

    // Trailing gap after the last segment. Using end() instead of the old
    // `&input_wav[input_wav.size()]`, which applied operator[] out of
    // bounds (undefined behavior).
    output_wav.insert(output_wav.end(),
                      input_wav.begin() + current_start,
                      input_wav.end());
}
38 changes: 15 additions & 23 deletions Wakeup/Vad/silerovad.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,53 +88,43 @@ class VadIterator {

public:
void reset_states();
void process(const std::vector<float> &input_wav);
void process(const std::vector<float> &input_wav,
std::vector<float> &output_wav);
void collect_chunks(const std::vector<float> &input_wav,
std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const {
return speeches;
}
void drop_chunks(const std::vector<float> &input_wav,
std::vector<float> &output_wav);
bool vadDetect(const std::vector<float> &input_wav);

private:
// model config
int64_t window_size_samples; // Assign when init, support 256 512 768 for
// 8k; 512 1024 1536 for 16k.
int sample_rate; // Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16
float threshold;
uint32_t min_silence_samples; // sr_per_ms * #ms
uint32_t min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k.
int sample_rate; //Assign when init support 16000 or 8000
int sr_per_ms; // Assign when init, support 8 or 16
float threshold;
int min_silence_samples; // sr_per_ms * #ms
int min_silence_samples_at_max_speech; // sr_per_ms * #98
int min_speech_samples; // sr_per_ms * #ms
float max_speech_samples;
int speech_pad_samples; // usually a
int speech_pad_samples; // usually a
int audio_length_samples;

// model states
bool triggered = false;
unsigned int temp_end = 0;
unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
unsigned int current_sample = 0;
// MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes
int prev_end;
int next_start = 0;

// Output timestamp
//Output timestamp
std::vector<timestamp_t> speeches;
timestamp_t current_speech;

// Onnx model
// Inputs
std::vector<Ort::Value> ort_inputs;

std::vector<const char *> input_node_names = {"input", "state", "sr"};
std::vector<float> input;
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
std::vector<float> _state;
std::vector<int64_t> sr;

int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
const int64_t sr_node_dims[1] = {1};
Expand Down Expand Up @@ -170,6 +160,8 @@ class SileroVad : public VadModel {
qint64 currentSlient;
bool findVoice;
std::vector<float> samples;
int abandonNum = 5;
int abandonCurrent = 0;
};

#endif // SILEROVAD_H
2 changes: 1 addition & 1 deletion Wakeup/wakeup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ Wakeup::Wakeup(Player *player, QObject *parent)
isPlaying = player->isPlaying();
player->pause();
recorder->pause();
player->playSoundEffect(Config::getDataPath("start.wav"), false);
player->playSoundEffect(Config::getDataPath("start.wav"), true);
recorder->resume();
vadModel->startDetect();
detectState = VAD;
Expand Down
11 changes: 11 additions & 0 deletions doc/配置.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ ls -lh sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
| ------ | ------------------- | ------------------------------ |
| vadRes | vad_aicar_v0.16.bin | 资源文件,可自行在控制台中下载 |

## silerovad

选项: `silero`

| 选项 | 默认值 | 含义 |
| ------ | ------------------- | ------------------------------ |
| model | silero_vad.onnx | 模型路径 |
| chunkSize | 1024 | silero v5中默认为1024,v4默认为640 |
| thresh | 0.1 | 阈值,越大越难检测出声音 |
| abandonNum | 5 | 检测开始时抛弃的块数量 |

# asr

# sherpa asr
Expand Down

0 comments on commit 96e225f

Please sign in to comment.