Skip to content

Commit

Permalink
support ngram and fst hotword for 2pass-offline (#1205)
Browse files Browse the repository at this point in the history
  • Loading branch information
lyblsgo authored Dec 26, 2023
1 parent b635c06 commit b882590
Show file tree
Hide file tree
Showing 14 changed files with 208 additions and 31 deletions.
37 changes: 33 additions & 4 deletions runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std:
}

void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) {
float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_,
float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map<string, int> hws_map) {

struct timeval start, end;
long seconds = 0;
float n_total_length = 0.0f;
long n_total_time = 0;

FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale);
// load hotwords list and build graph
FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map);

std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);

Expand Down Expand Up @@ -90,7 +95,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
} else {
is_final = false;
}
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
if (result)
{
FunASRFreeResult(result);
Expand Down Expand Up @@ -139,7 +145,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
is_final = false;
}
gettimeofday(&start, NULL);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final,
sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
Expand Down Expand Up @@ -197,6 +204,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<stri
*total_time = n_total_time;
}
}
FunWfstDecoderUnloadHwsRes(decoder_handle);
FunASRWfstDecoderUninit(decoder_handle);
FunTpassOnlineUninit(tpass_online_handle);
}

Expand All @@ -215,6 +224,11 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");

TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
Expand All @@ -231,6 +245,11 @@ int main(int argc, char** argv)
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(lm_dir);
cmd.add(global_beam);
cmd.add(lattice_beam);
cmd.add(am_scale);
cmd.add(fst_inc_wts);
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(asr_mode);
Expand All @@ -248,6 +267,7 @@ int main(int argc, char** argv)
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(wav_path, WAV_PATH, model_path);
GetValue(asr_mode, ASR_MODE, model_path);

Expand All @@ -272,6 +292,14 @@ int main(int argc, char** argv)
LOG(ERROR) << "FunTpassInit init failed";
exit(-1);
}
float glob_beam = 3.0f;
float lat_beam = 3.0f;
float am_sc = 10.0f;
if (lm_dir.isSet()) {
glob_beam = global_beam.getValue();
lat_beam = lattice_beam.getValue();
am_sc = am_scale.getValue();
}

gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
Expand Down Expand Up @@ -321,7 +349,8 @@ int main(int argc, char** argv)
int rtf_threds = thread_num_.getValue();
for (int i = 0; i < rtf_threds; i++)
{
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_));
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map));
}

for (auto& thread : threads)
Expand Down
32 changes: 30 additions & 2 deletions runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");

Expand All @@ -65,6 +70,11 @@ int main(int argc, char** argv)
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(lm_dir);
cmd.add(global_beam);
cmd.add(lattice_beam);
cmd.add(am_scale);
cmd.add(fst_inc_wts);
cmd.add(itn_dir);
cmd.add(wav_path);
cmd.add(audio_fs);
Expand All @@ -81,6 +91,7 @@ int main(int argc, char** argv)
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(wav_path, WAV_PATH, model_path);
GetValue(asr_mode, ASR_MODE, model_path);
Expand All @@ -106,6 +117,16 @@ int main(int argc, char** argv)
LOG(ERROR) << "FunTpassInit init failed";
exit(-1);
}
float glob_beam = 3.0f;
float lat_beam = 3.0f;
float am_sc = 10.0f;
if (lm_dir.isSet()) {
glob_beam = global_beam.getValue();
lat_beam = lattice_beam.getValue();
am_sc = am_scale.getValue();
}
// init wfst decoder
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc);

gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
Expand Down Expand Up @@ -146,6 +167,9 @@ int main(int argc, char** argv)
wav_ids.emplace_back(default_id);
}

// load hotwords list and build graph
FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map);

std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
// init online features
std::vector<int> chunk_size = {5,10,5};
Expand Down Expand Up @@ -191,7 +215,9 @@ int main(int argc, char** argv)
is_final = false;
}
gettimeofday(&start, NULL);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm",
(ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
Expand Down Expand Up @@ -235,10 +261,12 @@ int main(int argc, char** argv)
}
}
}


FunWfstDecoderUnloadHwsRes(decoder_handle);
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
FunASRWfstDecoderUninit(decoder_handle);
FunTpassOnlineUninit(tpass_online_handle);
FunTpassUninit(tpass_handle);
return 0;
Expand Down
1 change: 0 additions & 1 deletion runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
// warm up
for (size_t i = 0; i < 1; i++)
{
FunOfflineReset(asr_handle, decoder_handle);
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle);
if(result){
FunASRFreeResult(result);
Expand Down
2 changes: 1 addition & 1 deletion runtime/onnxruntime/bin/funasr-onnx-offline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string");
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
Expand Down
2 changes: 1 addition & 1 deletion runtime/onnxruntime/include/funasrruntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true,
int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true);
const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);

Expand Down
6 changes: 3 additions & 3 deletions runtime/onnxruntime/src/audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ float Audio::GetTimeLen()
void Audio::WavResample(int32_t sampling_rate, const float *waveform,
int32_t n)
{
LOG(INFO) << "Creating a resampler:\n"
<< " in_sample_rate: "<< sampling_rate << "\n"
<< " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
LOG(INFO) << "Creating a resampler: "
<< " in_sample_rate: "<< sampling_rate
<< " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
float min_freq =
std::min<int32_t>(sampling_rate, dest_sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
Expand Down
17 changes: 14 additions & 3 deletions runtime/onnxruntime/src/funasrruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
int sampling_rate, std::string wav_format, ASR_TYPE mode,
const std::vector<std::vector<float>> &hw_emb, bool itn)
const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
{
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
Expand Down Expand Up @@ -511,7 +511,12 @@
// timestamp
std::string cur_stamp = "[";
while(audio->FetchTpass(frame) > 0){
string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
// Reset the WFST decoder (when one was supplied) before decoding this frame
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);

std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
if(msg_vec.size()==0){
Expand Down Expand Up @@ -761,9 +766,15 @@
if (asr_type == ASR_OFFLINE) {
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
} else if (asr_type == ASR_TWO_PASS){
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
Expand Down
11 changes: 9 additions & 2 deletions runtime/onnxruntime/src/paraformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ void Paraformer::InitLm(const std::string &lm_file,
lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
fst::Fst<fst::StdArc>::Read(lm_file));
if (lm_){
if (vocab) { delete vocab; }
vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
LOG(INFO) << "Successfully load lm file " << lm_file;
}else{
LOG(ERROR) << "Failed to load lm file " << lm_file;
Expand Down Expand Up @@ -310,6 +309,9 @@ Paraformer::~Paraformer()
if(vocab){
delete vocab;
}
if(lm_vocab){
delete lm_vocab;
}
if(seg_dict){
delete seg_dict;
}
Expand Down Expand Up @@ -687,6 +689,11 @@ Vocab* Paraformer::GetVocab()
return vocab;
}

// Accessor for the language-model vocabulary stored in lm_vocab.
// Returns the raw member pointer as-is; no ownership is transferred here,
// and the value is whatever lm_vocab currently holds (possibly nullptr).
Vocab* Paraformer::GetLmVocab()
{
    return this->lm_vocab;
}

PhoneSet* Paraformer::GetPhoneSet()
{
return phone_set_;
Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/src/paraformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace funasr {
*/
private:
Vocab* vocab = nullptr;
Vocab* lm_vocab = nullptr;
SegDict* seg_dict = nullptr;
PhoneSet* phone_set_ = nullptr;
//const float scale = 22.6274169979695;
Expand Down Expand Up @@ -65,6 +66,7 @@ namespace funasr {
string FinalizeDecode(WfstDecoder* &wfst_decoder,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
Vocab* GetVocab();
Vocab* GetLmVocab();
PhoneSet* GetPhoneSet();

knf::FbankOptions fbank_opts_;
Expand Down
14 changes: 14 additions & 0 deletions runtime/onnxruntime/src/tpass-stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@ TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thr
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
exit(-1);
}

// LM resources (optional): loaded only when LM_DIR is set and the lexicon file exists
if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
string fst_path, lm_config_path, lex_path;
fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
if (access(lex_path.c_str(), F_OK) != 0 )
{
LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
}else{
asr_handle->InitLm(fst_path, lm_config_path, lex_path);
}
}

// PUNC model
if(model_path.find(PUNC_DIR) != model_path.end()){
Expand Down
5 changes: 4 additions & 1 deletion runtime/websocket/bin/funasr-wss-client-2pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ class WebsocketClient {
funasr::Audio audio(1);
int32_t sampling_rate = audio_fs;
std::string wav_format = "pcm";
if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false))
return;
} else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
} else {
wav_format = "others";
Expand Down
Loading

0 comments on commit b882590

Please sign in to comment.