diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index d4e69046da..298d59616e 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -216,6 +216,18 @@ def get_args(): """, ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + parser.add_argument( "sound_files", type=str, @@ -290,6 +302,7 @@ def main(): lm_scale=args.lm_scale, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, ) elif args.zipformer2_ctc: recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 2ef0aee7f6..3a2ff3b8f3 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -102,6 +102,17 @@ def get_args(): """, ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) return parser.parse_args() @@ -130,6 +141,7 @@ def create_recognizer(args): provider=args.provider, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, ) return recognizer diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index a5aecb67d4..e4fb1d1d80 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -111,6 +111,17 @@ def get_args(): """, ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) return parser.parse_args() @@ -136,6 +147,7 @@ def create_recognizer(args): provider=args.provider, hotwords_file=args.hotwords_file, hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, ) return recognizer diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index a06a7dedad..b5f500e08f 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): """, ) +def add_blank_penalty_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) def add_endpointing_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -284,6 +296,7 @@ def get_args(): add_decoding_args(parser) add_endpointing_args(parser) add_hotwords_args(parser) + add_blank_penalty_args(parser) parser.add_argument( "--port", @@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: max_active_paths=args.num_active_paths, hotwords_score=args.hotwords_score, hotwords_file=args.hotwords_file, + blank_penalty=args.blank_penalty, enable_endpoint_detection=args.use_endpoint != 0, rule1_min_trailing_silence=args.rule1_min_trailing_silence, rule2_min_trailing_silence=args.rule2_min_trailing_silence, diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 3ac757ea74..8b193a67ea 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, unk_id_); + config_.lm_config.scale, unk_id_, config_.blank_penalty); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), unk_id_); + model_.get(), unk_id_, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); @@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, - config_.lm_config.scale, unk_id_); + config_.lm_config.scale, unk_id_, config_.blank_penalty); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( - model_.get(), unk_id_); + model_.get(), unk_id_, config_.blank_penalty); } else { SHERPA_ONNX_LOGE("Unsupported decoding method: %s", config.decoding_method.c_str()); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index c3fb728df1..9cf7930627 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "True to enable endpoint detection. False to disable it."); po->Register("max-active-paths", &max_active_paths, "beam size used in modified beam search."); + po->Register("blank-penalty", &blank_penalty, + "The penalty applied on blank symbol during decoding. " + "Note: It is a positive value. " + "Increasing value will lead to lower deletion at the cost" + "of higher insertions. " + "Currently only applicable for transducer models."); po->Register("hotwords-score", &hotwords_score, "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); @@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_score=" << hotwords_score << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; - os << "decoding_method=\"" << decoding_method << "\")"; + os << "decoding_method=\"" << decoding_method << "\", "; + os << "blank_penalty=" << blank_penalty << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index a8b173851d..f0580d9cf6 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -83,6 +83,8 @@ struct OnlineRecognizerConfig { float hotwords_score = 1.5; std::string hotwords_file; + float blank_penalty = 0.0; + OnlineRecognizerConfig() = default; OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, @@ -92,7 +94,8 @@ struct OnlineRecognizerConfig { bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, - const std::string &hotwords_file, float hotwords_score) + const std::string &hotwords_file, float hotwords_score, + float blank_penalty) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -101,7 +104,8 @@ struct OnlineRecognizerConfig { decoding_method(decoding_method), max_active_paths(max_active_paths), hotwords_score(hotwords_score), - hotwords_file(hotwords_file) {} + hotwords_file(hotwords_file), + blank_penalty(blank_penalty) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 132aa87d25..e79a8e2f4f 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value logit = model_->RunJoiner( std::move(cur_encoder_out), View(&decoder_out)); - const float *p_logit = logit.GetTensorData(); + float *p_logit = logit.GetTensorMutableData(); bool emitted = false; for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { auto &r = (*result)[i]; + if (blank_penalty_ > 0.0) { + p_logit[0] -= blank_penalty_; // assuming blank id is 0 + } auto y = static_cast(std::distance( static_cast(p_logit), std::max_element(static_cast(p_logit), diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h index 363cefeddc..dd9faf8e8d 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -15,8 +15,9 @@ namespace sherpa_onnx { class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { public: OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, - int32_t unk_id) - : model_(model), unk_id_(unk_id) {} + int32_t unk_id, + float blank_penalty) + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { private: OnlineTransducerModel *model_; // Not owned int32_t unk_id_; + float blank_penalty_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 2694b4bd12..f676f45da9 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); + if (blank_penalty_ > 0.0) { + // assuming blank id is 0 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); + } LogSoftmax(p_logit, vocab_size, num_hyps); // now p_logit contains log_softmax output, we rename it to p_logprob diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h index bc0cfb5595..92e9a69c9d 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h @@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, OnlineLM *lm, int32_t max_active_paths, - float lm_scale, int32_t unk_id) + float lm_scale, int32_t unk_id, + float blank_penalty) : model_(model), lm_(lm), max_active_paths_(max_active_paths), lm_scale_(lm_scale), - unk_id_(unk_id) {} + unk_id_(unk_id), + blank_penalty_(blank_penalty) {} OnlineTransducerDecoderResult GetEmptyResult() const override; @@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder int32_t max_active_paths_; float lm_scale_; // used only when lm_ is not nullptr int32_t unk_id_; + float blank_penalty_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index d77be38671..c40b541cd0 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { py::class_(*m, "OnlineRecognizerConfig") .def(py::init(), + const std::string &, int32_t, const std::string &, float, + float>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), py::arg("enable_endpoint"), py::arg("decoding_method"), py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0) + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index a9f9e7d59e..c92b605fdf 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -48,6 +48,7 @@ def from_transducer( decoding_method: str = "greedy_search", max_active_paths: int = 4, hotwords_score: float = 1.5, + blank_penalty: float = 0.0, hotwords_file: str = "", provider: str = "cpu", model_type: str = "", @@ -100,6 +101,8 @@ def from_transducer( max_active_paths: Use only when decoding_method is modified_beam_search. It specifies the maximum number of active paths during beam search. + blank_penalty: + The penalty applied on blank symbol during decoding. hotwords_file: The file containing hotwords, one words/phrases per line, and for each phrase the bpe/cjkchar are separated by a space. @@ -172,6 +175,7 @@ def from_transducer( max_active_paths=max_active_paths, hotwords_score=hotwords_score, hotwords_file=hotwords_file, + blank_penalty=blank_penalty, ) self.recognizer = _Recognizer(recognizer_config)