Skip to content

Commit

Permalink
Update suppress_sentences in en-US.ini and transcription-filter-data.h
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed Apr 22, 2024
1 parent f0c33c0 commit a4c84ae
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 14 deletions.
1 change: 1 addition & 0 deletions data/locale/en-US.ini
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ translate_add_context="Translate with context"
whisper_translate="Translate to English (Whisper)"
buffer_size_msec="Buffer size (ms)"
overlap_size_msec="Overlap size (ms)"
suppress_sentences="Suppress sentences (each line)"
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct transcription_filter_data {
std::string target_lang;
bool buffered_output = false;
bool enable_token_ts_dtw = false;
std::string suppress_sentences;

// Last transcription result
std::string last_text;
Expand Down
7 changes: 6 additions & 1 deletion src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->source_lang = obs_data_get_string(s, "translate_source_language");
gf->target_lang = obs_data_get_string(s, "translate_target_language");
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");

if (new_translate != gf->translate) {
if (new_translate) {
Expand Down Expand Up @@ -610,7 +611,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
send_caption_to_source(text, gf);
}
},
20,
30,
std::chrono::seconds(10));

obs_log(gf->log_level, "run update");
Expand Down Expand Up @@ -723,6 +724,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "translate_target_language", "__es__");
obs_data_set_default_string(s, "translate_source_language", "__en__");
obs_data_set_default_bool(s, "translate_add_context", true);
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);

// Whisper parameters
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
Expand Down Expand Up @@ -922,6 +924,9 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_int(list, "INFO", LOG_INFO);
obs_property_list_add_int(list, "WARNING", LOG_WARNING);

// add a text input for sentences to suppress
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"), OBS_TEXT_MULTILINE);

obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
OBS_GROUP_NORMAL, whisper_params_group);
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai\">OCC AI</a> ❤️ "
"<a href=\"https://www.patreon.com/RoyShilkrot\">Support & Follow</a>";

// Default value for the "suppress_sentences" filter setting: a newline-separated
// list of sentences; transcription output that matches an entry is suppressed
// (see DETECTION_RESULT_SUPPRESSED handling in whisper-processing.cpp).
const char* const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";

#ifdef __cplusplus
}
#endif
18 changes: 14 additions & 4 deletions src/whisper-utils/token-buffer-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void TokenBufferThread::addWords(const std::vector<whisper_token_data> &words)
log_token_vector(words);

// run reconstructSentence
std::vector<whisper_token_data> reconstructed = reconstructSentence(currentWords, words);
std::vector<whisper_token_data> reconstructed = reconstructSentence(currentWords, words);

log_token_vector(reconstructed);

Expand All @@ -56,6 +56,8 @@ void TokenBufferThread::addWords(const std::vector<whisper_token_data> &words)
for (const auto &word : reconstructed) {
wordQueue.push_back(word);
}

newDataAvailable = true;
}
condVar.notify_all();
}
Expand Down Expand Up @@ -102,11 +104,19 @@ void TokenBufferThread::monitor()
this->wordQueue.push_front(*it);
}

if (this->wordQueue.size() >= this->maxSize ||
std::chrono::steady_clock::now() - startTime >= this->maxTime) {
// check if we need to flush the queue
auto elapsedTime = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - startTime);
if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
// flush the queue if it's full or we've reached the max time
size_t words_to_flush =
std::min(this->wordQueue.size(), this->maxSize);
// make sure we leave at least 3 words in the queue
size_t words_remaining = this->wordQueue.size() - words_to_flush;
if (words_remaining < 3) {
words_to_flush -= 3 - words_remaining;
}
obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words", words_to_flush);
for (size_t i = 0; i < words_to_flush; ++i) {
wordQueue.pop_front();
}
Expand All @@ -116,4 +126,4 @@ void TokenBufferThread::monitor()
newDataAvailable = false;
}
obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
}
}
25 changes: 20 additions & 5 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
std::string(token_str) + "), ";
tokens.push_back(token);
}
obs_log(LOG_INFO, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
obs_log(gf->log_level, "Token %d: %d, %s, p: %.3f, dtw: %ld [keep: %d]", j,
token.id, token_str, token.p, token.t_dtw, keep);
}
sentence_p /= (float)n_tokens;
obs_log(LOG_INFO, "Decoded sentence: '%s'", text.c_str());
obs_log(LOG_INFO, "Token IDs: %s", tokenIds.c_str());
obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str());
obs_log(gf->log_level, "Token IDs: %s", tokenIds.c_str());

// convert text to lowercase
std::string text_lower(text);
Expand All @@ -312,6 +312,21 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
.base(),
text_lower.end());

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
std::string suppress_sentences_copy = gf->suppress_sentences;
size_t pos = 0;
std::string token;
while ((pos = suppress_sentences_copy.find("\n")) != std::string::npos) {
token = suppress_sentences_copy.substr(0, pos);
suppress_sentences_copy.erase(0, pos + 1);
if (text_lower == suppress_sentences_copy) {
obs_log(gf->log_level, "Suppressing sentence: %s", text_lower.c_str());
return {DETECTION_RESULT_SUPPRESSED, "", 0, 0, {}};
}
}
}

if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
Expand Down Expand Up @@ -456,8 +471,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)

if (last_step_in_segment) {
const uint64_t overlap_size_ms = (uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
obs_log(gf->log_level,
"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
obs_log(gf->log_level,
"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
gf->overlap_frames, overlap_size_ms, gf->last_num_frames - gf->overlap_frames);
for (size_t c = 0; c < gf->channels; c++) {
// This is the last step in the segment - reset the copy buffer (include overlap frames)
Expand Down
1 change: 1 addition & 0 deletions src/whisper-utils/whisper-processing.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum DetectionResult {
DETECTION_RESULT_UNKNOWN = 0,
DETECTION_RESULT_SILENCE = 1,
DETECTION_RESULT_SPEECH = 2,
// Speech was transcribed but matched an entry in the user's
// suppress_sentences list, so the text is dropped instead of emitted.
DETECTION_RESULT_SUPPRESSED = 3,
};

struct DetectionResultWithText {
Expand Down
29 changes: 25 additions & 4 deletions src/whisper-utils/whisper-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &se
if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
return {-1, -1};
}
for (int i = 0; i < seq1.size() - 1; ++i) {
for (int i = seq1.size() - 2; i >= seq1.size() / 2; --i) {

Check failure on line 167 in src/whisper-utils/whisper-utils.cpp

View workflow job for this annotation

GitHub Actions / Build Project 🧱 / Build for macOS 🍏 (x86_64)

comparison of integers of different signs: 'int' and 'size_type' (aka 'unsigned long') [-Werror,-Wsign-compare]
for (int j = 0; j < seq2.size() - 1; ++j) {

Check failure on line 168 in src/whisper-utils/whisper-utils.cpp

View workflow job for this annotation

GitHub Actions / Build Project 🧱 / Build for macOS 🍏 (x86_64)

comparison of integers of different signs: 'int' and 'size_type' (aka 'unsigned long') [-Werror,-Wsign-compare]
if (seq1[i].id == seq2[j].id) {
// Check if the next token in both sequences is the same
Expand Down Expand Up @@ -194,17 +194,38 @@ std::vector<whisper_token_data> reconstructSentence(const std::vector<whisper_to
std::vector<whisper_token_data> reconstructed;

if (overlap.first == -1 || overlap.second == -1) {
if (seq1.empty() && seq2.empty()) {
return reconstructed;
}
if (seq1.empty()) {
return seq2;
}
if (seq2.empty()) {
return seq1;
}

// Return concat of seq1 and seq2 if no overlap found

// check if the last token of seq1 == the first token of seq2
if (!seq1.empty() && !seq2.empty() && seq1.back().id == seq2.front().id) {
if (seq1.back().id == seq2.front().id) {
// don't add the last token of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
} else if (seq2.size() > 1 && seq1.back().id == seq2[1].id) {
// check if the last token of seq1 == the second token of seq2
// don't add the last token of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
// don't add the first token of seq2
reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end());
} else if (seq1.size() > 1 && seq1[seq1.size() - 2].id == seq2.front().id) {
// check if the second to last token of seq1 == the first token of seq2
// don't add the last two tokens of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2);
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
} else {
// add all tokens of seq1
reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
}
reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
return reconstructed;
}

Expand Down

0 comments on commit a4c84ae

Please sign in to comment.