diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp index e591b322d..d49ba7231 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp @@ -44,12 +44,17 @@ void GetValue(TCLAP::ValueArg& value_arg, string key, std::map chunk_size, vector wav_list, vector wav_ids, int audio_fs, - float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) { + float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_, + float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map hws_map) { struct timeval start, end; long seconds = 0; float n_total_length = 0.0f; long n_total_time = 0; + + FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale); + // load hotwords list and build graph + FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map); std::vector> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS); @@ -90,7 +95,8 @@ void runReg(FUNASR_HANDLE tpass_handle, std::vector chunk_size, vector chunk_size, vector chunk_size, vector punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); TCLAP::ValueArg itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); + TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); + TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); + TCLAP::ValueArg fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t"); TCLAP::ValueArg asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); TCLAP::ValueArg onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); @@ -231,6 +245,11 @@ int main(int argc, char** argv) cmd.add(punc_dir); cmd.add(punc_quant); cmd.add(itn_dir); + cmd.add(lm_dir); + cmd.add(global_beam); + cmd.add(lattice_beam); + cmd.add(am_scale); + cmd.add(fst_inc_wts); cmd.add(wav_path); cmd.add(audio_fs); cmd.add(asr_mode); @@ -248,6 +267,7 @@ int main(int argc, char** argv) GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); GetValue(itn_dir, ITN_DIR, model_path); + GetValue(lm_dir, LM_DIR, model_path); GetValue(wav_path, WAV_PATH, model_path); GetValue(asr_mode, ASR_MODE, model_path); @@ -272,6 +292,14 @@ int main(int argc, char** argv) LOG(ERROR) << "FunTpassInit init failed"; exit(-1); } + float glob_beam = 3.0f; + float lat_beam = 3.0f; + float am_sc = 10.0f; + if (lm_dir.isSet()) { + glob_beam = global_beam.getValue(); + lat_beam = lattice_beam.getValue(); + am_sc = am_scale.getValue(); + } gettimeofday(&end, NULL); long seconds = (end.tv_sec - start.tv_sec); @@ -321,7 +349,8 @@ int main(int argc, char** argv) int rtf_threds = thread_num_.getValue(); for (int i = 0; i < rtf_threds; i++) { - threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_)); + threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_, + glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map)); } for (auto& thread : threads) diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp index b21092771..abcc4b2e7 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp @@ -51,6 +51,11 @@ int main(int argc, char** argv) TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); TCLAP::ValueArg itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); + TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); + TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); + TCLAP::ValueArg fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t"); TCLAP::ValueArg asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); TCLAP::ValueArg onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); @@ -65,6 +70,11 @@ int main(int argc, char** argv) cmd.add(vad_quant); cmd.add(punc_dir); cmd.add(punc_quant); + cmd.add(lm_dir); + cmd.add(global_beam); + cmd.add(lattice_beam); + cmd.add(am_scale); + cmd.add(fst_inc_wts); cmd.add(itn_dir); cmd.add(wav_path); cmd.add(audio_fs); @@ -81,6 +91,7 @@ int main(int argc, char** argv) GetValue(vad_quant, VAD_QUANT, model_path); GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); + GetValue(lm_dir, LM_DIR, model_path); GetValue(itn_dir, ITN_DIR, model_path); GetValue(wav_path, WAV_PATH, model_path); GetValue(asr_mode, ASR_MODE, model_path); @@ -106,6 +117,16 @@ int main(int argc, char** argv) LOG(ERROR) << "FunTpassInit init failed"; exit(-1); } + float glob_beam = 3.0f; + float lat_beam = 3.0f; + float am_sc = 10.0f; + if (lm_dir.isSet()) { + glob_beam = global_beam.getValue(); + lat_beam = lattice_beam.getValue(); + am_sc = am_scale.getValue(); + } + // init wfst decoder + FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc); gettimeofday(&end, NULL); long seconds = (end.tv_sec - start.tv_sec); @@ -146,6 +167,9 @@ int main(int argc, char** argv) wav_ids.emplace_back(default_id); } + // load hotwords list and build graph + FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map); + std::vector> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS); // init online features std::vector chunk_size = {5,10,5}; @@ -191,7 +215,9 @@ int main(int argc, char** argv) is_final = false; } gettimeofday(&start, NULL); - FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding); + FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, + speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", + (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle); gettimeofday(&end, NULL); seconds = (end.tv_sec - start.tv_sec); taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); @@ -235,10 +261,12 @@ int main(int argc, char** argv) } } } - + + FunWfstDecoderUnloadHwsRes(decoder_handle); LOG(INFO) << "Audio length: " << (double)snippet_time << " s"; LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s"; LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000); + FunASRWfstDecoderUninit(decoder_handle); FunTpassOnlineUninit(tpass_online_handle); FunTpassUninit(tpass_handle); return 0; diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp index b1a7c870c..83d7e79d8 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp @@ -54,7 +54,6 @@ void runReg(FUNASR_HANDLE asr_handle, vector wav_list, vector wa // warm up for (size_t i = 0; i < 1; i++) { - FunOfflineReset(asr_handle, decoder_handle); FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle); if(result){ FunASRFreeResult(result); diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp index 55eda9395..4aaa0023e 100644 --- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp +++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp @@ -50,7 +50,7 @@ int main(int argc, char** argv) TCLAP::ValueArg vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); TCLAP::ValueArg punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); TCLAP::ValueArg punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); - TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string"); + TCLAP::ValueArg lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); TCLAP::ValueArg lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); TCLAP::ValueArg am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h index 27ee6c6ba..cff617f38 100644 --- a/runtime/onnxruntime/include/funasrruntime.h +++ b/runtime/onnxruntime/include/funasrruntime.h @@ -119,7 +119,7 @@ _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std:: _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector> &punc_cache, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, - const std::vector> &hw_emb={{0.0}}, bool itn=true); + const std::vector> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr); _FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle); _FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle); diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp index 559e3ddfc..c471329d0 100644 --- a/runtime/onnxruntime/src/audio.cpp +++ b/runtime/onnxruntime/src/audio.cpp @@ -254,9 +254,9 @@ float Audio::GetTimeLen() void Audio::WavResample(int32_t sampling_rate, const float *waveform, int32_t n) { - LOG(INFO) << "Creating a resampler:\n" - << " in_sample_rate: "<< sampling_rate << "\n" - << " output_sample_rate: " << static_cast(dest_sample_rate); + LOG(INFO) << "Creating a resampler: " + << " in_sample_rate: "<< sampling_rate + << " output_sample_rate: " << static_cast(dest_sample_rate); float min_freq = std::min(sampling_rate, dest_sample_rate); float lowpass_cutoff = 0.99 * 0.5 * min_freq; diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp index ccd0412b5..c4cb9d935 100644 --- a/runtime/onnxruntime/src/funasrruntime.cpp +++ b/runtime/onnxruntime/src/funasrruntime.cpp @@ -437,7 +437,7 @@ _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode, - const std::vector> &hw_emb, bool itn) + const std::vector> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle) { funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle; funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle; @@ -511,7 +511,12 @@ // timestamp std::string cur_stamp = "["; while(audio->FetchTpass(frame) > 0){ - string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb); + // dec reset + funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle; + if (wfst_decoder){ + wfst_decoder->StartUtterance(); + } + string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle); std::vector msg_vec = funasr::split(msg, '|'); // split with timestamp if(msg_vec.size()==0){ @@ -761,9 +766,15 @@ if (asr_type == ASR_OFFLINE) { funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get(); + if (paraformer->lm_) + mm = new funasr::WfstDecoder(paraformer->lm_.get(), + paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale); + } else if (asr_type == ASR_TWO_PASS){ + funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle; + funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get(); if (paraformer->lm_) mm = new funasr::WfstDecoder(paraformer->lm_.get(), - paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale); + paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale); } return mm; } diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp index bb15ac723..c56421cca 100644 --- a/runtime/onnxruntime/src/paraformer.cpp +++ b/runtime/onnxruntime/src/paraformer.cpp @@ -193,8 +193,7 @@ void Paraformer::InitLm(const std::string &lm_file, lm_ = std::shared_ptr>( fst::Fst::Read(lm_file)); if (lm_){ - if (vocab) { delete vocab; } - vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str()); + lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str()); LOG(INFO) << "Successfully load lm file " << lm_file; }else{ LOG(ERROR) << "Failed to load lm file " << lm_file; @@ -310,6 +309,9 @@ Paraformer::~Paraformer() if(vocab){ delete vocab; } + if(lm_vocab){ + delete lm_vocab; + } if(seg_dict){ delete seg_dict; } @@ -687,6 +689,11 @@ Vocab* Paraformer::GetVocab() return vocab; } +Vocab* Paraformer::GetLmVocab() +{ + return lm_vocab; +} + PhoneSet* Paraformer::GetPhoneSet() { return phone_set_; diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h index de0565773..5bb9477bf 100644 --- a/runtime/onnxruntime/src/paraformer.h +++ b/runtime/onnxruntime/src/paraformer.h @@ -20,6 +20,7 @@ namespace funasr { */ private: Vocab* vocab = nullptr; + Vocab* lm_vocab = nullptr; SegDict* seg_dict = nullptr; PhoneSet* phone_set_ = nullptr; //const float scale = 22.6274169979695; @@ -65,6 +66,7 @@ namespace funasr { string FinalizeDecode(WfstDecoder* &wfst_decoder, bool is_stamp=false, std::vector us_alphas={0}, std::vector us_cif_peak={0}); Vocab* GetVocab(); + Vocab* GetLmVocab(); PhoneSet* GetPhoneSet(); knf::FbankOptions fbank_opts_; diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp index a3e1b0eb3..b723e0fa1 100644 --- a/runtime/onnxruntime/src/tpass-stream.cpp +++ b/runtime/onnxruntime/src/tpass-stream.cpp @@ -66,6 +66,20 @@ TpassStream::TpassStream(std::map& model_path, int thr LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir"; exit(-1); } + + // Lm resource + if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") { + string fst_path, lm_config_path, lex_path; + fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES); + lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME); + lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH); + if (access(lex_path.c_str(), F_OK) != 0 ) + { + LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model."; + }else{ + asr_handle->InitLm(fst_path, lm_config_path, lex_path); + } + } // PUNC model if(model_path.find(PUNC_DIR) != model_path.end()){ diff --git a/runtime/websocket/bin/funasr-wss-client-2pass.cpp b/runtime/websocket/bin/funasr-wss-client-2pass.cpp index 6533dd556..0cbd10e23 100644 --- a/runtime/websocket/bin/funasr-wss-client-2pass.cpp +++ b/runtime/websocket/bin/funasr-wss-client-2pass.cpp @@ -192,7 +192,10 @@ class WebsocketClient { funasr::Audio audio(1); int32_t sampling_rate = audio_fs; std::string wav_format = "pcm"; - if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) { + if (funasr::IsTargetFile(wav_path.c_str(), "wav")) { + if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) + return; + } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) { if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return; } else { wav_format = "others"; diff --git a/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/runtime/websocket/bin/funasr-wss-server-2pass.cpp index 965f2a8c9..ef27d5b4a 100644 --- a/runtime/websocket/bin/funasr-wss-server-2pass.cpp +++ b/runtime/websocket/bin/funasr-wss-server-2pass.cpp @@ -16,6 +16,7 @@ // hotwords std::unordered_map hws_map_; int fst_inc_wts_=20; +float global_beam_, lattice_beam_, am_scale_; using namespace std; void GetValue(TCLAP::ValueArg& value_arg, string key, @@ -120,6 +121,14 @@ int main(int argc, char* argv[]) { "connection", false, "../../../ssl_key/server.key", "string"); + TCLAP::ValueArg global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); + TCLAP::ValueArg am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); + + TCLAP::ValueArg lm_dir("", LM_DIR, + "the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string"); + TCLAP::ValueArg lm_revision( + "", "lm-revision", "LM model revision", false, "v1.0.2", "string"); TCLAP::ValueArg hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "/workspace/resources/hotwords.txt", "string"); @@ -128,6 +137,10 @@ int main(int argc, char* argv[]) { // add file cmd.add(hotword); + cmd.add(fst_inc_wts); + cmd.add(global_beam); + cmd.add(lattice_beam); + cmd.add(am_scale); cmd.add(certfile); cmd.add(keyfile); @@ -146,6 +159,8 @@ int main(int argc, char* argv[]) { cmd.add(punc_quant); cmd.add(itn_dir); cmd.add(itn_revision); + cmd.add(lm_dir); + cmd.add(lm_revision); cmd.add(listen_ip); cmd.add(port); @@ -163,6 +178,7 @@ int main(int argc, char* argv[]) { GetValue(punc_dir, PUNC_DIR, model_path); GetValue(punc_quant, PUNC_QUANT, model_path); GetValue(itn_dir, ITN_DIR, model_path); + GetValue(lm_dir, LM_DIR, model_path); GetValue(hotword, HOTWORD, model_path); GetValue(offline_model_revision, "offline-model-revision", model_path); @@ -170,6 +186,11 @@ int main(int argc, char* argv[]) { GetValue(vad_revision, "vad-revision", model_path); GetValue(punc_revision, "punc-revision", model_path); GetValue(itn_revision, "itn-revision", model_path); + GetValue(lm_revision, "lm-revision", model_path); + + global_beam_ = global_beam.getValue(); + lattice_beam_ = lattice_beam.getValue(); + am_scale_ = am_scale.getValue(); // Download model form Modelscope try { @@ -183,6 +204,7 @@ int main(int argc, char* argv[]) { std::string s_punc_path = model_path[PUNC_DIR]; std::string s_punc_quant = model_path[PUNC_QUANT]; std::string s_itn_path = model_path[ITN_DIR]; + std::string s_lm_path = model_path[LM_DIR]; std::string python_cmd = "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True "; @@ -241,11 +263,18 @@ int main(int argc, char* argv[]) { size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404"); if (found != std::string::npos) { model_path["offline-model-revision"]="v1.2.4"; - } else{ - found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); - if (found != std::string::npos) { - model_path["offline-model-revision"]="v1.0.5"; - } + } + + found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); + if (found != std::string::npos) { + model_path["offline-model-revision"]="v1.0.5"; + } + + found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020"); + if (found != std::string::npos) { + model_path["model-revision"]="v1.0.0"; + s_itn_path=""; + s_lm_path=""; } if (access(s_offline_asr_path.c_str(), F_OK) == 0) { @@ -332,6 +361,49 @@ int main(int argc, char* argv[]) { LOG(INFO) << "ASR online model is not set, use default."; } + if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") { + std::string python_cmd_lm; + std::string down_lm_path; + std::string down_lm_model; + + if (access(s_lm_path.c_str(), F_OK) == 0) { + // local + python_cmd_lm = python_cmd + " --model-name " + s_lm_path + + " --export-dir ./ " + " --model_revision " + + model_path["lm-revision"] + " --export False "; + down_lm_path = s_lm_path; + } else { + // modelscope + LOG(INFO) << "Download model: " << s_lm_path + << " from modelscope : "; + python_cmd_lm = python_cmd + " --model-name " + + s_lm_path + + " --export-dir " + s_download_model_dir + + " --model_revision " + model_path["lm-revision"] + + " --export False "; + down_lm_path = + s_download_model_dir + + "/" + s_lm_path; + } + + int ret = system(python_cmd_lm.c_str()); + if (ret != 0) { + LOG(INFO) << "Failed to download model from modelscope. If you set local lm model path, you can ignore the errors."; + } + down_lm_model = down_lm_path + "/TLG.fst"; + + if (access(down_lm_model.c_str(), F_OK) != 0) { + LOG(ERROR) << down_lm_model << " do not exists."; + exit(-1); + } else { + model_path[LM_DIR] = down_lm_path; + LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR]; + } + } else { + LOG(INFO) << "LM model is not set, not executed."; + model_path[LM_DIR] = ""; + } + if (!s_punc_path.empty()) { std::string python_cmd_punc; std::string down_punc_path; diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp index 44dd82e5e..0269e5ff2 100644 --- a/runtime/websocket/bin/websocket-server-2pass.cpp +++ b/runtime/websocket/bin/websocket-server-2pass.cpp @@ -18,6 +18,7 @@ extern std::unordered_map hws_map_; extern int fst_inc_wts_; +extern float global_beam_, lattice_beam_, am_scale_; context_ptr WebSocketServer::on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl, @@ -102,7 +103,8 @@ void WebSocketServer::do_decoder( bool itn, int audio_fs, std::string wav_format, - FUNASR_HANDLE& tpass_online_handle) { + FUNASR_HANDLE& tpass_online_handle, + FUNASR_DEC_HANDLE& decoder_handle) { // lock for each connection if(!tpass_online_handle){ scoped_lock guard(thread_lock); @@ -131,7 +133,7 @@ void WebSocketServer::do_decoder( subvector.data(), subvector.size(), punc_cache, false, audio_fs, wav_format, (ASR_TYPE)asr_mode_, - hotwords_embedding, itn); + hotwords_embedding, itn, decoder_handle); } else { scoped_lock guard(thread_lock); @@ -168,7 +170,7 @@ void WebSocketServer::do_decoder( buffer.data(), buffer.size(), punc_cache, is_final, audio_fs, wav_format, (ASR_TYPE)asr_mode_, - hotwords_embedding, itn); + hotwords_embedding, itn, decoder_handle); } else { scoped_lock guard(thread_lock); msg["access_num"]=(int)msg["access_num"]-1; @@ -241,6 +243,9 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) { data_msg->msg["audio_fs"] = 16000; // default is 16k data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly data_msg->msg["is_eof"]=false; // if this connection is closed + FUNASR_DEC_HANDLE decoder_handle = + FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_); + data_msg->decoder_handle = decoder_handle; data_msg->punc_cache = std::make_shared>>(2); data_msg->strand_ = std::make_shared(io_decoder_); @@ -267,6 +272,9 @@ void remove_hdl( // finished and avoid access freed tpass_online_handle unique_lock guard_decoder(*(data_msg->thread_lock)); if (data_msg->msg["access_num"]==0 && data_msg->msg["is_eof"]==true) { + FunWfstDecoderUnloadHwsRes(data_msg->decoder_handle); + FunASRWfstDecoderUninit(data_msg->decoder_handle); + data_msg->decoder_handle = nullptr; FunTpassOnlineUninit(data_msg->tpass_online_handle); data_msg->tpass_online_handle = nullptr; data_map.erase(hdl); @@ -431,7 +439,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl, nn_hotwords += " " + pair.first; LOG(INFO) << pair.first << " : " << pair.second; } - // FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map); + FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map); // nn std::vector> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS); @@ -483,7 +491,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl, msg_data->msg["itn"], msg_data->msg["audio_fs"], msg_data->msg["wav_format"], - std::ref(msg_data->tpass_online_handle))); + std::ref(msg_data->tpass_online_handle), + std::ref(msg_data->decoder_handle))); msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1; } catch (std::exception const &e) @@ -530,7 +539,8 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl, msg_data->msg["itn"], msg_data->msg["audio_fs"], msg_data->msg["wav_format"], - std::ref(msg_data->tpass_online_handle))); + std::ref(msg_data->tpass_online_handle), + std::ref(msg_data->decoder_handle))); msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1; } } diff --git a/runtime/websocket/bin/websocket-server-2pass.h b/runtime/websocket/bin/websocket-server-2pass.h index 3e78a3438..6b2ba325f 100644 --- a/runtime/websocket/bin/websocket-server-2pass.h +++ b/runtime/websocket/bin/websocket-server-2pass.h @@ -60,7 +60,8 @@ typedef struct { FUNASR_HANDLE tpass_online_handle=NULL; std::string online_res = ""; std::string tpass_res = ""; - std::shared_ptr strand_; // for data execute in order + std::shared_ptr strand_; // for data execute in order + FUNASR_DEC_HANDLE decoder_handle=NULL; } FUNASR_MESSAGE; // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about @@ -123,7 +124,8 @@ class WebSocketServer { bool itn, int audio_fs, std::string wav_format, - FUNASR_HANDLE& tpass_online_handle); + FUNASR_HANDLE& tpass_online_handle, + FUNASR_DEC_HANDLE& decoder_handle); void initAsr(std::map& model_path, int thread_num); void on_message(websocketpp::connection_hdl hdl, message_ptr msg);