diff --git a/buildspec.json b/buildspec.json index 8ba5638..a1babb6 100644 --- a/buildspec.json +++ b/buildspec.json @@ -38,7 +38,7 @@ }, "name": "obs-localvocal", "displayName": "OBS Localvocal", - "version": "0.2.3", + "version": "0.2.4", "author": "Roy Shilkrot", "website": "https://github.com/occ-ai/obs-localvocal", "email": "roy.shil@gmail.com", diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp index c7f9d40..1cdde5a 100644 --- a/src/transcription-utils.cpp +++ b/src/transcription-utils.cpp @@ -2,6 +2,7 @@ #include #include +#include #define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0) #define is_trail_byte(c) (((c)&0xc0) == 0x80) @@ -102,3 +103,16 @@ std::string remove_leading_trailing_nonalpha(const std::string &str) })); return str_copy; } + +std::vector split(const std::string &string, char delimiter) +{ + std::vector tokens; + std::string token; + std::istringstream tokenStream(string); + while (std::getline(tokenStream, token, delimiter)) { + if (!token.empty()) { + tokens.push_back(token); + } + } + return tokens; +} diff --git a/src/transcription-utils.h b/src/transcription-utils.h index 5e2e500..c4dce8a 100644 --- a/src/transcription-utils.h +++ b/src/transcription-utils.h @@ -2,8 +2,10 @@ #define TRANSCRIPTION_UTILS_H #include +#include std::string fix_utf8(const std::string &str); std::string remove_leading_trailing_nonalpha(const std::string &str); +std::vector split(const std::string &string, char delimiter); #endif // TRANSCRIPTION_UTILS_H diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 9970619..7d46275 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -6,6 +6,7 @@ #include "transcription-filter-data.h" #include "whisper-processing.h" #include "whisper-utils.h" +#include "transcription-utils.h" #include #include @@ -282,6 +283,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') { keep = false; } + // if this is a special token, don't keep it + if (token.id >= 50256) { + keep = false; + } if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) { keep = false; } @@ -312,20 +317,18 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter // if suppression is enabled, check if the text is in the suppression list if (!gf->suppress_sentences.empty()) { - std::string suppress_sentences_copy = gf->suppress_sentences; - size_t pos = 0; - std::string token; - while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) { - token = suppress_sentences_copy.substr(0, pos); - suppress_sentences_copy.erase(0, pos + 1); - if (text == suppress_sentences_copy) { - obs_log(gf->log_level, "Suppressing sentence: %s", + // split the suppression list by newline into individual sentences + std::vector suppress_sentences_list = + split(gf->suppress_sentences, '\n'); + // check if the text is in the suppression list + for (const std::string &suppress_sentence : suppress_sentences_list) { + if (text.find(suppress_sentence) != std::string::npos) { + obs_log(gf->log_level, "Suppressed sentence: '%s'", text.c_str()); - return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}}; + return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}}; } } } - if (gf->log_words) { obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), sentence_p, text.c_str());