Skip to content

Commit

Permalink
Add context biasing for mobile (k2-fsa#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
ductranminh authored Feb 1, 2024
1 parent 558f5e3 commit 665b869
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,17 @@ class SherpaOnnx(
acceptWaveform(ptr, samples, sampleRate)

fun inputFinished() = inputFinished(ptr)
fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate)
fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
fun decode() = decode(ptr)
fun isEndpoint(): Boolean = isEndpoint(ptr)
fun isReady(): Boolean = isReady(ptr)

val text: String
get() = getText(ptr)

val tokens: Array<String>
get() = getTokens(ptr)

private external fun delete(ptr: Long)

private external fun new(
Expand All @@ -107,10 +110,11 @@ class SherpaOnnx(
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getText(ptr: Long): String
private external fun reset(ptr: Long, recreate: Boolean)
private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
private external fun decode(ptr: Long)
private external fun isEndpoint(ptr: Long): Boolean
private external fun isReady(ptr: Long): Boolean
private external fun getTokens(ptr: Long): Array<String>

companion object {
init {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,17 @@ class SherpaOnnx(
acceptWaveform(ptr, samples, sampleRate)

fun inputFinished() = inputFinished(ptr)
fun reset(recreate: Boolean = false) = reset(ptr, recreate = recreate)
fun reset(recreate: Boolean = false, hotwords: String = "") = reset(ptr, recreate, hotwords)
fun decode() = decode(ptr)
fun isEndpoint(): Boolean = isEndpoint(ptr)
fun isReady(): Boolean = isReady(ptr)

val text: String
get() = getText(ptr)

val tokens: Array<String>
get() = getTokens(ptr)

private external fun delete(ptr: Long)

private external fun new(
Expand All @@ -142,10 +145,11 @@ class SherpaOnnx(
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getText(ptr: Long): String
private external fun reset(ptr: Long, recreate: Boolean)
private external fun reset(ptr: Long, recreate: Boolean, hotwords: String)
private external fun decode(ptr: Long)
private external fun isEndpoint(ptr: Long): Boolean
private external fun isReady(ptr: Long): Boolean
private external fun getTokens(ptr: Long): Array<String>

companion object {
init {
Expand Down
28 changes: 22 additions & 6 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,24 @@ class SherpaOnnx {

bool IsReady() const { return recognizer_.IsReady(stream_.get()); }

void Reset(bool recreate) {
if (recreate) {
stream_ = recognizer_.CreateStream();
// If keywords is an empty string, it just recreates the decoding stream
// If keywords is not empty, it will create a new decoding stream with
// the given keywords appended to the default keywords.
void Reset(bool recreate, const std::string &keywords = {}) {
if (keywords.empty()) {
if (recreate) {
stream_ = recognizer_.CreateStream();
} else {
recognizer_.Reset(stream_.get());
}
} else {
recognizer_.Reset(stream_.get());
auto stream = recognizer_.CreateStream(keywords);
// Set new keywords failed, the stream_ will not be updated.
if (stream != nullptr) {
stream_ = std::move(stream);
} else {
SHERPA_ONNX_LOGE("Failed to set keywords: %s", keywords.c_str());
}
}
}

Expand Down Expand Up @@ -1509,9 +1522,12 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate) {
JNIEnv *env, jobject /*obj*/,
jlong ptr, jboolean recreate, jstring keywords) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
model->Reset(recreate);
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
model->Reset(recreate, p_keywords);
env->ReleaseStringUTFChars(keywords, p_keywords);
}

SHERPA_ONNX_EXTERN_C
Expand Down
21 changes: 18 additions & 3 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class SherpaOnnxOnlineRecongitionResult {
class SherpaOnnxRecognizer {
/// A pointer to the underlying counterpart in C
let recognizer: OpaquePointer!
let stream: OpaquePointer!
var stream: OpaquePointer!

/// Constructor taking a model config
init(
Expand Down Expand Up @@ -237,8 +237,23 @@ class SherpaOnnxRecognizer {

/// Reset the recognizer, which clears the neural network model state
/// and the state for decoding.
func reset() {
Reset(recognizer, stream)
/// If hotwords is an empty string, it just recreates the decoding stream
/// If hotwords is not empty, it will create a new decoding stream with
/// the given hotWords appended to the default hotwords.
func reset(hotwords: String? = nil) {
guard let words = hotwords, !words.isEmpty else {
Reset(recognizer, stream)
return
}

words.withCString { cString in
let newStream = CreateOnlineStreamWithHotwords(recognizer, cString)
// lock while release and replace stream
objc_sync_enter(self)
DestroyOnlineStream(stream)
stream = newStream
objc_sync_exit(self)
}
}

/// Signal that no more audio samples would be available.
Expand Down

0 comments on commit 665b869

Please sign in to comment.