From a4c84ae56eb3eec4b6f4ef30b35d168f168eafc4 Mon Sep 17 00:00:00 2001
From: Roy Shilkrot
Date: Mon, 22 Apr 2024 10:57:01 -0400
Subject: [PATCH] Update suppress_sentences in en-US.ini and
 transcription-filter-data.h

---
 data/locale/en-US.ini                     |  1 +
 src/transcription-filter-data.h           |  1 +
 src/transcription-filter.cpp              |  7 +++++-
 src/transcription-filter.h                |  2 ++
 src/whisper-utils/token-buffer-thread.cpp | 18 ++++++++++----
 src/whisper-utils/whisper-processing.cpp  | 25 +++++++++++++++----
 src/whisper-utils/whisper-processing.h    |  1 +
 src/whisper-utils/whisper-utils.cpp       | 29 +++++++++++++++++++----
 8 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini
index 6f7c1e9..c294eae 100644
--- a/data/locale/en-US.ini
+++ b/data/locale/en-US.ini
@@ -51,3 +51,4 @@ translate_add_context="Translate with context"
 whisper_translate="Translate to English (Whisper)"
 buffer_size_msec="Buffer size (ms)"
 overlap_size_msec="Overlap size (ms)"
+suppress_sentences="Suppress sentences (each line)"
diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h
index 09d88b4..fea8d33 100644
--- a/src/transcription-filter-data.h
+++ b/src/transcription-filter-data.h
@@ -80,6 +80,7 @@ struct transcription_filter_data {
 	std::string target_lang;
 	bool buffered_output = false;
 	bool enable_token_ts_dtw = false;
+	std::string suppress_sentences;
 
 	// Last transcription result
 	std::string last_text;
diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp
index bf4a48a..4fa13bf 100644
--- a/src/transcription-filter.cpp
+++ b/src/transcription-filter.cpp
@@ -363,6 +363,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
 	gf->source_lang = obs_data_get_string(s, "translate_source_language");
 	gf->target_lang = obs_data_get_string(s, "translate_target_language");
 	gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
+	gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
 
 	if (new_translate != gf->translate) {
 		if (new_translate) {
@@ -610,7 +611,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 				send_caption_to_source(text, gf);
 			}
 		},
-		20,
+		30,
 		std::chrono::seconds(10));
 
 	obs_log(gf->log_level, "run update");
@@ -723,6 +724,7 @@ void transcription_filter_defaults(obs_data_t *s)
 	obs_data_set_default_string(s, "translate_target_language", "__es__");
 	obs_data_set_default_string(s, "translate_source_language", "__en__");
 	obs_data_set_default_bool(s, "translate_add_context", true);
+	obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
 
 	// Whisper parameters
 	obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
@@ -922,6 +924,9 @@ obs_properties_t *transcription_filter_properties(void *data)
 	obs_property_list_add_int(list, "INFO", LOG_INFO);
 	obs_property_list_add_int(list, "WARNING", LOG_WARNING);
 
+	// add a text input for sentences to suppress
+	obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"), OBS_TEXT_MULTILINE);
+
 	obs_properties_t *whisper_params_group = obs_properties_create();
 	obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
 				 OBS_GROUP_NORMAL, whisper_params_group);
diff --git a/src/transcription-filter.h b/src/transcription-filter.h
index 6784540..c79a37a 100644
--- a/src/transcription-filter.h
+++ b/src/transcription-filter.h
@@ -19,6 +19,8 @@ const char *const PLUGIN_INFO_TEMPLATE =
 	"OCC AI ❤️ "
 	"Support & Follow";
 
+const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp
index ff67953..fd7083b 100644
--- a/src/whisper-utils/token-buffer-thread.cpp
+++ b/src/whisper-utils/token-buffer-thread.cpp
@@ -45,7 +45,7 @@ void TokenBufferThread::addWords(const std::vector &words)
 	log_token_vector(words);
 
 	// run reconstructSentence
-	std::vector reconstructed = reconstructSentence(currentWords, words);
+	std::vector reconstructed = reconstructSentence(currentWords, words);
 
 	log_token_vector(reconstructed);
 
@@ -56,6 +56,8 @@ void TokenBufferThread::addWords(const std::vector &words)
 		for (const auto &word : reconstructed) {
 			wordQueue.push_back(word);
 		}
+
+		newDataAvailable = true;
 	}
 	condVar.notify_all();
 }
@@ -102,11 +104,19 @@ void TokenBufferThread::monitor()
 			this->wordQueue.push_front(*it);
 		}
 
-		if (this->wordQueue.size() >= this->maxSize ||
-		    std::chrono::steady_clock::now() - startTime >= this->maxTime) {
+		// check if we need to flush the queue
+		auto elapsedTime = std::chrono::duration_cast(
+			std::chrono::steady_clock::now() - startTime);
+		if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
 			// flush the queue if it's full or we've reached the max time
 			size_t words_to_flush = std::min(this->wordQueue.size(), this->maxSize);
+			// make sure we leave at least 3 words in the queue
+			size_t words_remaining = this->wordQueue.size() - words_to_flush;
+			if (words_remaining < 3) {
+				words_to_flush -= 3 - words_remaining;
+			}
+
 			obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words", words_to_flush);
 			for (size_t i = 0; i < words_to_flush; ++i) {
 				wordQueue.pop_front();
 			}
@@ -116,4 +126,4 @@ void TokenBufferThread::monitor()
 		newDataAvailable = false;
 	}
 	obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
-}
\ No newline at end of file
+}
diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index 5d7cc61..11f9f4f 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -296,12 +296,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 					      std::string(token_str) + "), ";
 			tokens.push_back(token);
 		}
-		obs_log(LOG_INFO, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
+		obs_log(gf->log_level, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
 			token.id, token_str, token.p, token.t_dtw, keep);
 	}
 	sentence_p /= (float)n_tokens;
-	obs_log(LOG_INFO, "Decoded sentence: '%s'", text.c_str());
-	obs_log(LOG_INFO, "Token IDs: %s", tokenIds.c_str());
+	obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
+	obs_log(gf->log_level, "Token IDs: %s", tokenIds.c_str());
 
 	// convert text to lowercase
 	std::string text_lower(text);
@@ -312,6 +312,21 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 				 .base(),
 			 text_lower.end());
 
+	// if suppression is enabled, check if the text is in the suppression list
+	if (!gf->suppress_sentences.empty()) {
+		std::string suppress_sentences_copy = gf->suppress_sentences;
+		size_t pos = 0;
+		std::string token;
+		while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
+			token = suppress_sentences_copy.substr(0, pos);
+			suppress_sentences_copy.erase(0, pos + 1);
+			if (text_lower == token) {
+				obs_log(gf->log_level, "Suppressing sentence: %s", text_lower.c_str());
+				return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
+			}
+		}
+	}
+
 	if (gf->log_words) {
 		obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
 			to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
@@ -456,8 +471,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 
 	if (last_step_in_segment) {
 		const uint64_t overlap_size_ms = (uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
-		obs_log(gf->log_level,
-			"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
+		obs_log(gf->log_level,
+			"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
 			gf->overlap_frames, overlap_size_ms, gf->last_num_frames - gf->overlap_frames);
 		for (size_t c = 0; c < gf->channels; c++) {
 			// This is the last step in the segment - reset the copy buffer (include overlap frames)
diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h
index 409aaab..358c390 100644
--- a/src/whisper-utils/whisper-processing.h
+++ b/src/whisper-utils/whisper-processing.h
@@ -12,6 +12,7 @@ enum DetectionResult {
 	DETECTION_RESULT_UNKNOWN = 0,
 	DETECTION_RESULT_SILENCE = 1,
 	DETECTION_RESULT_SPEECH = 2,
+	DETECTION_RESULT_SUPPRESSED = 3,
 };
 
 struct DetectionResultWithText {
diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp
index 183afbc..732877b 100644
--- a/src/whisper-utils/whisper-utils.cpp
+++ b/src/whisper-utils/whisper-utils.cpp
@@ -164,7 +164,7 @@ std::pair findStartOfOverlap(const std::vector &se
 	if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
 		return {-1, -1};
 	}
-	for (int i = 0; i < seq1.size() - 1; ++i) {
+	for (int i = seq1.size() - 2; i >= seq1.size() / 2; --i) {
 		for (int j = 0; j < seq2.size() - 1; ++j) {
 			if (seq1[i].id == seq2[j].id) {
 				// Check if the next token in both sequences is the same
@@ -194,17 +194,38 @@ std::vector reconstructSentence(const std::vector
 	std::vector reconstructed;
 
 	if (overlap.first == -1 || overlap.second == -1) {
+		if (seq1.empty() && seq2.empty()) {
+			return reconstructed;
+		}
+		if (seq1.empty()) {
+			return seq2;
+		}
+		if (seq2.empty()) {
+			return seq1;
+		}
+
 		// Return concat of seq1 and seq2 if no overlap found
-		// check if the last token of seq1 == the first token of seq2
-		if (!seq1.empty() && !seq2.empty() && seq1.back().id == seq2.front().id) {
+		if (seq1.back().id == seq2.front().id) {
+			// don't add the last token of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
+		} else if (seq2.size() > 1 && seq1.back().id == seq2[1].id) {
+			// check if the last token of seq1 == the second token of seq2
 			// don't add the last token of seq1
 			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
+			// don't add the first token of seq2
+			reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end());
+		} else if (seq1.size() > 1 && seq1[seq1.size() - 2].id == seq2.front().id) {
+			// check if the second to last token of seq1 == the first token of seq2
+			// don't add the last two tokens of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2);
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
 		} else {
 			// add all tokens of seq1
 			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
+			reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
 		}
-		reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
 		return reconstructed;
 	}