From 5f48c3dd7ba76ec4e95bfb473eb356930b1ed519 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 3 Jun 2024 13:29:12 -0400 Subject: [PATCH 01/12] refactor: Update whispercpp dependency to version 0.0.3 --- CMakeLists.txt | 8 +- data/locale/en-US.ini | 6 + src/model-utils/model-infos.cpp | 54 +++ src/tests/localvocal-offline-test.cpp | 30 +- src/transcription-filter-callbacks.cpp | 228 +++++++++ src/transcription-filter-callbacks.h | 20 + src/transcription-filter-data.h | 7 +- src/transcription-filter-utils.cpp | 55 +++ src/transcription-filter-utils.h | 33 ++ src/transcription-filter.cpp | 566 ++++++++-------------- src/transcription-utils.cpp | 20 + src/transcription-utils.h | 34 +- src/translation/translation-includes.h | 8 + src/translation/translation-utils.cpp | 7 +- src/translation/translation-utils.h | 4 + src/translation/translation.cpp | 10 +- src/translation/translation.h | 17 +- src/utils.cpp | 21 - src/utils.h | 9 - src/whisper-utils/token-buffer-thread.cpp | 254 ++++++---- src/whisper-utils/token-buffer-thread.h | 36 +- src/whisper-utils/whisper-model-utils.cpp | 14 +- src/whisper-utils/whisper-processing.cpp | 205 +------- 23 files changed, 905 insertions(+), 741 deletions(-) create mode 100644 src/transcription-filter-callbacks.cpp create mode 100644 src/transcription-filter-callbacks.h create mode 100644 src/transcription-filter-utils.cpp create mode 100644 src/transcription-filter-utils.h create mode 100644 src/translation/translation-includes.h delete mode 100644 src/utils.cpp delete mode 100644 src/utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ac977f..8b58b07 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,6 +87,8 @@ target_sources( PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c + src/transcription-filter-callbacks.cpp + src/transcription-filter-utils.cpp src/transcription-utils.cpp src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp @@ -99,8 +101,7 @@ target_sources( src/whisper-utils/token-buffer-thread.cpp src/translation/language_codes.cpp src/translation/translation.cpp - src/translation/translation-utils.cpp - src/utils.cpp) + src/translation/translation-utils.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) @@ -121,8 +122,7 @@ if(ENABLE_TESTS) src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp src/translation/language_codes.cpp - src/translation/translation.cpp - src/utils.cpp) + src/translation/translation.cpp) find_libav(${CMAKE_PROJECT_NAME}-tests) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index ad1d07b..c9b8757 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -61,3 +61,9 @@ sentence_psum_accept_thresh="Sentence prob. 
threshold" external_model_folder="External model folder" load_external_model="Load external model" translate_input_tokenization_style="Input token style" +translation_sampling_temperature="Sampling temperature" +translation_repetition_penalty="Repetition penalty" +translation_beam_size="Beam size" +translation_max_decoding_length="Max decoding length" +translation_no_repeat_ngram_size="No-repeat ngram size" +translation_max_input_length="Max input length" diff --git a/src/model-utils/model-infos.cpp b/src/model-utils/model-infos.cpp index 1c21aa1..4f7b19a 100644 --- a/src/model-utils/model-infos.cpp +++ b/src/model-utils/model-infos.cpp @@ -45,6 +45,60 @@ std::map models_info = {{ "B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"}, {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", "D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}}, + {"NLLB 200 1.3B (1.4Gb)", + {"NLLB 200 1.3B", + "nllb-200-1.3b", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/model.bin?download=true", + "72D7533DC7A0E8F10F19A650D4E90FAF9CBFA899DB5411AD124BD5802BD91263"}, + { + "https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/config.json?download=true", + "0C2F6FA2057C7264D052FB4A62BA3476EEAE70487ACDDFA8E779A53A00CBF44C", + }, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/tokenizer.json?download=true", + "E316B82DE11D0F951F370943B3C438311629547285129B0B81DADABD01BCA665"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/shared_vocabulary.txt?download=true", + "A132A83330F45514C2476EB81D1D69B3C41762264D16CE0A7EA982E5D6C728E5"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "992BD4ED610D644D6823081937BCC91BB8878DD556CEA4AE5327F2480361330E"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "D1AA8C3697D3E35674F97B5B7E9C99D22B010F528E80140257D97316BE90D044"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "14BB8DFB35C0FFDEA7BC01E56CEA38B9E3D5EFCDCB9C251D6B40538E1AAB555A"}}}}, + {"NLLB 200 600M (650Mb)", + {"NLLB 200 600M", + "nllb-200-600m", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/model.bin?download=true", + "ED1BEAF75134DE7505315A5223162F56ACFF397EFF6B50638A500D3936FE707B"}, + { + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/config.json?download=true", + "0C2F6FA2057C7264D052FB4A62BA3476EEAE70487ACDDFA8E779A53A00CBF44C", + }, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer.json?download=true", + "E316B82DE11D0F951F370943B3C438311629547285129B0B81DADABD01BCA665"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/shared_vocabulary.txt?download=true", + "A132A83330F45514C2476EB81D1D69B3C41762264D16CE0A7EA982E5D6C728E5"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "992BD4ED610D644D6823081937BCC91BB8878DD556CEA4AE5327F2480361330E"}, + 
{"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "D1AA8C3697D3E35674F97B5B7E9C99D22B010F528E80140257D97316BE90D044"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "14BB8DFB35C0FFDEA7BC01E56CEA38B9E3D5EFCDCB9C251D6B40538E1AAB555A"}}}}, + {"MADLAD 400 3B (2.9Gb)", + {"MADLAD 400 3B", + "madlad-400-3b", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/model.bin?download=true", + "F3C87256A2C888100C179D7DCD7F41DF17C767469546C59D32C7DDE86C740A6B"}, + { + "https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/config.json?download=true", + "A428C51CD35517554523B3C6B6974A5928BC35E82B130869A543566A34A83B93", + }, + {"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/shared_vocabulary.txt?download=true", + "C327551CE3CA6EFC7B437E11A267F79979893332DDA8A1D146E2C950815193F8"}, + {"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/sentencepiece.model?download=true", + "EF11AC9A22C7503492F56D48DCE53BE20E339B63605983E9F27D2CD0E0F3922C"}}}}, {"Whisper Base q5 (57Mb)", {"Whisper Base q5", "whisper-base-q5", diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index e87f571..40b0033 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -13,6 +13,7 @@ #include #include "transcription-filter-data.h" +#include "transcription-filter-utils.h" #include "transcription-filter.h" #include "transcription-utils.h" #include "whisper-utils/whisper-utils.h" @@ -84,7 +85,6 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p gf->sample_rate = sample_rate; gf->frames = (size_t)((float)gf->sample_rate * 10.0f); gf->last_num_frames = 0; - gf->step_size_msec = 3000; gf->min_sub_duration = 3000; gf->last_sub_render_time = 0; gf->save_srt = false; @@ -110,8 +110,6 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p memset(gf->copy_buffers[0], 0, gf->channels * gf->frames * sizeof(float)); obs_log(LOG_INFO, " allocated %llu bytes ", gf->channels * gf->frames * sizeof(float)); - gf->overlap_ms = 150; - gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms)); obs_log(gf->log_level, "channels %d, frames %d, sample_rate %d", (int)gf->channels, (int)gf->frames, gf->sample_rate); @@ -158,11 +156,12 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p gf->whisper_params = whisper_full_default_params(whisper_sampling_method); gf->whisper_params.duration_ms = 3000; gf->whisper_params.language = "en"; + gf->whisper_params.detect_language = false; gf->whisper_params.initial_prompt = ""; gf->whisper_params.n_threads = 4; gf->whisper_params.n_max_text_ctx = 16384; gf->whisper_params.translate = false; - gf->whisper_params.no_context = true; + gf->whisper_params.no_context = false; gf->whisper_params.single_segment = true; gf->whisper_params.print_special = false; gf->whisper_params.print_progress = false; @@ -177,7 +176,7 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p gf->whisper_params.speed_up = false; gf->whisper_params.suppress_blank = true; gf->whisper_params.suppress_non_speech_tokens = true; - gf->whisper_params.temperature = 0.1; + gf->whisper_params.temperature = 0.0; gf->whisper_params.max_initial_ts = 1.0; gf->whisper_params.length_penalty = -1; gf->active = 
true;
@@ -204,7 +203,7 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
 	// 	numeral = "0" + numeral;
 	// }
 
-	// save the audio to a .wav file
+	// // save the audio to a .wav file
 	// std::string filename = "audio_chunk_" + numeral + vad_state_str + ".wav";
 	// obs_log(gf->log_level, "Saving %lu frames to %s", frames, filename.c_str());
 	// write_audio_wav_file(filename.c_str(), pcm32f_data, frames);
@@ -281,7 +280,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 				str_copy.c_str(), translated_text.c_str());
 		}
 		// overwrite the original text with the translated text
-		str_copy = str_copy + " -> " + translated_text;
+		str_copy = str_copy + " | " + translated_text;
 	} else {
 		obs_log(gf->log_level, "Failed to translate text");
 	}
@@ -385,19 +384,22 @@ int wmain(int argc, wchar_t *argv[])
 		gf->suppress_sentences = config["suppress_sentences"].get<std::string>();
 	}
-	if (config.contains("overlap_ms")) {
-		obs_log(LOG_INFO, "Setting overlap_ms to %d",
-			config["overlap_ms"].get<int>());
-		gf->overlap_ms = config["overlap_ms"];
-		gf->overlap_frames = (size_t)((float)gf->sample_rate /
-					      (1000.0f / (float)gf->overlap_ms));
-	}
 	if (config.contains("enable_audio_chunks_callback")) {
 		obs_log(LOG_INFO, "Setting enable_audio_chunks_callback to %s",
 			config["enable_audio_chunks_callback"] ? "true" : "false");
 		gf->enable_audio_chunks_callback = config["enable_audio_chunks_callback"];
 	}
+	if (config.contains("temperature")) {
+		obs_log(LOG_INFO, "Setting temperature to %f",
+			config["temperature"].get<float>());
+		gf->whisper_params.temperature = config["temperature"].get<float>();
+	}
+	if (config.contains("no_context")) {
+		obs_log(LOG_INFO, "Setting no_context to %s",
+			config["no_context"] ? "true" : "false");
+		gf->whisper_params.no_context = config["no_context"];
+	}
 	// set log level
 	if (logLevelStr == "debug") {
 		gf->log_level = LOG_DEBUG;
diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp
new file mode 100644
index 0000000..6d0e062
--- /dev/null
+++ b/src/transcription-filter-callbacks.cpp
@@ -0,0 +1,228 @@
+#ifdef _WIN32
+#define NOMINMAX
+#endif
+
+#include <obs-module.h>
+#include <obs-frontend-api.h>
+
+#include
+
+#include <fstream>
+#include <iomanip>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "transcription-filter-callbacks.h"
+#include "transcription-utils.h"
+#include "translation/translation.h"
+#include "translation/translation-includes.h"
+
+#define SEND_TIMED_METADATA_URL "http://localhost:8080/timed-metadata"
+
+void send_caption_to_source(const std::string &target_source_name, const std::string &caption,
+			    struct transcription_filter_data *gf)
+{
+	if (target_source_name.empty()) {
+		return;
+	}
+	auto target = obs_get_source_by_name(target_source_name.c_str());
+	if (!target) {
+		obs_log(gf->log_level, "text_source target is null");
+		return;
+	}
+	auto text_settings = obs_source_get_settings(target);
+	obs_data_set_string(text_settings, "text", caption.c_str());
+	obs_source_update(target, text_settings);
+	obs_source_release(target);
+}
+
+void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data,
+			  size_t frames, int vad_state, const DetectionResultWithText &result)
+{
+	UNUSED_PARAMETER(gf);
+	UNUSED_PARAMETER(pcm32f_data);
+	UNUSED_PARAMETER(frames);
+	UNUSED_PARAMETER(vad_state);
+	UNUSED_PARAMETER(result);
+	// stub
+}
+
+void set_text_callback(struct transcription_filter_data *gf,
+		       const DetectionResultWithText &resultIn)
+{
+	DetectionResultWithText result = resultIn;
+	uint64_t now = now_ms();
+	if (result.text.empty() || result.result != 
DETECTION_RESULT_SPEECH) { + // check if we should clear the current sub depending on the minimum subtitle duration + if ((now - gf->last_sub_render_time) > gf->min_sub_duration) { + // clear the current sub, run an empty sub + result.text = ""; + } else { + // nothing to do, the incoming sub is empty + return; + } + } + gf->last_sub_render_time = now; + + std::string str_copy = result.text; + + // recondition the text - only if the output is not English + if (gf->whisper_params.language != nullptr && + strcmp(gf->whisper_params.language, "en") != 0) { + str_copy = fix_utf8(str_copy); + } else { + // only remove leading and trailing non-alphanumeric characters if the output is English + str_copy = remove_leading_trailing_nonalpha(str_copy); + } + + // if suppression is enabled, check if the text is in the suppression list + if (!gf->suppress_sentences.empty()) { + // split the suppression list by newline into individual sentences + std::vector suppress_sentences_list = + split(gf->suppress_sentences, '\n'); + const std::string original_str_copy = str_copy; + // check if the text is in the suppression list + for (const std::string &suppress_sentence : suppress_sentences_list) { + // if suppress_sentence exists within str_copy, remove it (replace with "") + str_copy = std::regex_replace(str_copy, std::regex(suppress_sentence), ""); + } + // if the text was modified, log the original and modified text + if (original_str_copy != str_copy) { + obs_log(gf->log_level, "------ Suppressed text: '%s' -> '%s'", + original_str_copy.c_str(), str_copy.c_str()); + } + if (remove_leading_trailing_nonalpha(str_copy).empty()) { + // if the text is empty after suppression, return + return; + } + } + + if (gf->translate && !str_copy.empty() && str_copy != gf->last_text && + result.result == DETECTION_RESULT_SPEECH) { + obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(), + gf->target_lang.c_str()); + std::string translated_text; + if (translate(gf->translation_ctx, str_copy, gf->source_lang, gf->target_lang, + translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) { + if (gf->log_words) { + obs_log(LOG_INFO, "Translation: '%s' -> '%s'", str_copy.c_str(), + translated_text.c_str()); + } + if (gf->translation_output == "none") { + // overwrite the original text with the translated text + str_copy = translated_text; + } else { + // send the translation to the selected source + send_caption_to_source(gf->translation_output, translated_text, gf); + } + } else { + obs_log(gf->log_level, "Failed to translate text"); + } + } + + gf->last_text = str_copy; + + if (gf->buffered_output) { + gf->captions_monitor.addSentence(str_copy); + } + + if (gf->caption_to_stream) { + obs_output_t *streaming_output = obs_frontend_get_streaming_output(); + if (streaming_output) { + obs_output_output_caption_text1(streaming_output, str_copy.c_str()); + obs_output_release(streaming_output); + } + } + + if (gf->send_timed_metadata) { + send_timed_metadata(gf, result); + } + + if (gf->output_file_path != "" && gf->text_source_name.empty()) { + // Check if we should save the sentence + if (gf->save_only_while_recording && !obs_frontend_recording_active()) { + // We are not recording, do not save the sentence to file + return; + } + // should the file be truncated? 
+ std::ios_base::openmode openmode = std::ios::out; + if (gf->truncate_output_file) { + openmode |= std::ios::trunc; + } else { + openmode |= std::ios::app; + } + if (!gf->save_srt) { + // Write raw sentence to file + std::ofstream output_file(gf->output_file_path, openmode); + output_file << str_copy << std::endl; + output_file.close(); + } else { + obs_log(gf->log_level, "Saving sentence to file %s, sentence #%d", + gf->output_file_path.c_str(), gf->sentence_number); + // Append sentence to file in .srt format + std::ofstream output_file(gf->output_file_path, openmode); + output_file << gf->sentence_number << std::endl; + // use the start and end timestamps to calculate the start and end time in srt format + auto format_ts_for_srt = [&output_file](uint64_t ts) { + uint64_t time_s = ts / 1000; + uint64_t time_m = time_s / 60; + uint64_t time_h = time_m / 60; + uint64_t time_ms_rem = ts % 1000; + uint64_t time_s_rem = time_s % 60; + uint64_t time_m_rem = time_m % 60; + uint64_t time_h_rem = time_h % 60; + output_file << std::setfill('0') << std::setw(2) << time_h_rem + << ":" << std::setfill('0') << std::setw(2) + << time_m_rem << ":" << std::setfill('0') + << std::setw(2) << time_s_rem << "," + << std::setfill('0') << std::setw(3) << time_ms_rem; + }; + format_ts_for_srt(result.start_timestamp_ms); + output_file << " --> "; + format_ts_for_srt(result.end_timestamp_ms); + output_file << std::endl; + + output_file << str_copy << std::endl; + output_file << std::endl; + output_file.close(); + gf->sentence_number++; + } + } else { + if (!gf->buffered_output) { + // Send the caption to the text source + send_caption_to_source(gf->text_source_name, str_copy, gf); + } + } +}; + +void recording_state_callback(enum obs_frontend_event event, void *data) +{ + struct transcription_filter_data *gf_ = + static_cast(data); + if (event == OBS_FRONTEND_EVENT_RECORDING_STARTING) { + if (gf_->save_srt && gf_->save_only_while_recording) { + obs_log(gf_->log_level, "Recording started. Resetting srt file."); + // truncate file if it exists + std::ofstream output_file(gf_->output_file_path, + std::ios::out | std::ios::trunc); + output_file.close(); + gf_->sentence_number = 1; + gf_->start_timestamp_ms = now_ms(); + } + } else if (event == OBS_FRONTEND_EVENT_RECORDING_STOPPED) { + if (gf_->save_srt && gf_->save_only_while_recording && + gf_->rename_file_to_match_recording) { + obs_log(gf_->log_level, "Recording stopped. 
Rename srt file."); + // rename file to match the recording file name with .srt extension + // use obs_frontend_get_last_recording to get the last recording file name + std::string recording_file_name = obs_frontend_get_last_recording(); + // remove the extension + recording_file_name = recording_file_name.substr( + 0, recording_file_name.find_last_of(".")); + std::string srt_file_name = recording_file_name + ".srt"; + // rename the file + std::rename(gf_->output_file_path.c_str(), srt_file_name.c_str()); + } + } +} diff --git a/src/transcription-filter-callbacks.h b/src/transcription-filter-callbacks.h new file mode 100644 index 0000000..481af9f --- /dev/null +++ b/src/transcription-filter-callbacks.h @@ -0,0 +1,20 @@ +#ifndef TRANSCRIPTION_FILTER_CALLBACKS_H +#define TRANSCRIPTION_FILTER_CALLBACKS_H + +#include + +#include "transcription-filter-data.h" +#include "whisper-utils/whisper-processing.h" + +void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy, + struct transcription_filter_data *gf); + +void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data, + size_t frames, int vad_state, const DetectionResultWithText &result); + +void set_text_callback(struct transcription_filter_data *gf, + const DetectionResultWithText &resultIn); + +void recording_state_callback(enum obs_frontend_event event, void *data); + +#endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */ diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 93ea4ca..dc56b1a 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -15,6 +15,7 @@ #include #include "translation/translation.h" +#include "translation/translation-includes.h" #include "whisper-utils/silero-vad-onnx.h" #include "whisper-utils/whisper-processing.h" #include "whisper-utils/token-buffer-thread.h" @@ -27,13 +28,8 @@ struct transcription_filter_data { uint32_t sample_rate; // input sample rate // How many input frames (in input sample rate) are needed for the next whisper frame size_t frames; - // How many ms/frames are needed to overlap with the next whisper frame - size_t overlap_frames; - size_t overlap_ms; // How many frames were processed in the last whisper frame (this is dynamic) size_t last_num_frames; - // Milliseconds per processing step (e.g. 
rest of the whisper buffer may be filled with silence) - size_t step_size_msec; // Start begining timestamp in ms since epoch uint64_t start_timestamp_ms; // Sentence counter for srt @@ -83,6 +79,7 @@ struct transcription_filter_data { std::string suppress_sentences; bool fix_utf8 = true; bool enable_audio_chunks_callback = false; + bool source_signals_set = false; // Last transcription result std::string last_text; diff --git a/src/transcription-filter-utils.cpp b/src/transcription-filter-utils.cpp new file mode 100644 index 0000000..72f313c --- /dev/null +++ b/src/transcription-filter-utils.cpp @@ -0,0 +1,55 @@ +#include "transcription-filter-utils.h" + +#include +#include +#include + +void create_obs_text_source() +{ + // create a new OBS text source called "LocalVocal Subtitles" + obs_source_t *scene_as_source = obs_frontend_get_current_scene(); + obs_scene_t *scene = obs_scene_from_source(scene_as_source); +#ifdef _WIN32 + obs_source_t *source = + obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles", nullptr, nullptr); +#else + obs_source_t *source = + obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles", nullptr, nullptr); +#endif + if (source) { + // add source to the current scene + obs_scene_add(scene, source); + // set source settings + obs_data_t *source_settings = obs_source_get_settings(source); + obs_data_set_bool(source_settings, "word_wrap", true); + obs_data_set_int(source_settings, "custom_width", 1760); + obs_data_t *font_data = obs_data_create(); + obs_data_set_string(font_data, "face", "Arial"); + obs_data_set_string(font_data, "style", "Regular"); + obs_data_set_int(font_data, "size", 72); + obs_data_set_int(font_data, "flags", 0); + obs_data_set_obj(source_settings, "font", font_data); + obs_data_release(font_data); + obs_source_update(source, source_settings); + obs_data_release(source_settings); + + // set transform settings + obs_transform_info transform_info; + transform_info.pos.x = 962.0; + transform_info.pos.y = 959.0; + transform_info.bounds.x = 1769.0; + transform_info.bounds.y = 145.0; + transform_info.bounds_type = obs_bounds_type::OBS_BOUNDS_SCALE_INNER; + transform_info.bounds_alignment = OBS_ALIGN_CENTER; + transform_info.alignment = OBS_ALIGN_CENTER; + transform_info.scale.x = 1.0; + transform_info.scale.y = 1.0; + transform_info.rot = 0.0; + obs_sceneitem_t *source_sceneitem = obs_scene_sceneitem_from_source(scene, source); + obs_sceneitem_set_info(source_sceneitem, &transform_info); + obs_sceneitem_release(source_sceneitem); + + obs_source_release(source); + } + obs_source_release(scene_as_source); +} diff --git a/src/transcription-filter-utils.h b/src/transcription-filter-utils.h new file mode 100644 index 0000000..9f24d55 --- /dev/null +++ b/src/transcription-filter-utils.h @@ -0,0 +1,33 @@ +#ifndef TRANSCRIPTION_FILTER_UTILS_H +#define TRANSCRIPTION_FILTER_UTILS_H + +#include + +// Convert channels number to a speaker layout +inline enum speaker_layout convert_speaker_layout(uint8_t channels) +{ + switch (channels) { + case 0: + return SPEAKERS_UNKNOWN; + case 1: + return SPEAKERS_MONO; + case 2: + return SPEAKERS_STEREO; + case 3: + return SPEAKERS_2POINT1; + case 4: + return SPEAKERS_4POINT0; + case 5: + return SPEAKERS_4POINT1; + case 6: + return SPEAKERS_5POINT1; + case 8: + return SPEAKERS_7POINT1; + default: + return SPEAKERS_UNKNOWN; + } +} + +void create_obs_text_source(); + +#endif // TRANSCRIPTION_FILTER_UTILS_H diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 08626e6..1f8025f 100644 --- 
a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -1,9 +1,24 @@ #include #include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#define NOMINMAX +#include +#endif + +#include + #include "plugin-support.h" #include "transcription-filter.h" +#include "transcription-filter-callbacks.h" #include "transcription-filter-data.h" +#include "transcription-filter-utils.h" #include "transcription-utils.h" #include "model-utils/model-downloader.h" #include "whisper-utils/whisper-processing.h" @@ -13,19 +28,7 @@ #include "translation/language_codes.h" #include "translation/translation-utils.h" #include "translation/translation.h" -#include "utils.h" - -#include -#include -#include -#include -#include -#include -#ifdef _WIN32 -#include -#endif - -#include +#include "translation/translation-includes.h" bool add_sources_to_list(void *list_property, obs_source_t *source) { @@ -41,6 +44,83 @@ bool add_sources_to_list(void *list_property, obs_source_t *source) return true; } +void reset_caption_state(transcription_filter_data *gf_) +{ + if (gf_->captions_monitor.isEnabled()) { + gf_->captions_monitor.clear(); + } + send_caption_to_source(gf_->text_source_name, "", gf_); + // flush the buffer + { + std::lock_guard lock(gf_->whisper_buf_mutex); + for (size_t c = 0; c < gf_->channels; c++) { + if (gf_->input_buffers[c].data != nullptr) { + circlebuf_free(&gf_->input_buffers[c]); + } + } + if (gf_->info_buffer.data != nullptr) { + circlebuf_free(&gf_->info_buffer); + } + if (gf_->whisper_buffer.data != nullptr) { + circlebuf_free(&gf_->whisper_buffer); + } + } +} + +void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) +{ + signal_handler_t *sh = obs_source_get_signal_handler(parent_source); + signal_handler_connect( + sh, "media_play", + [](void *data_, calldata_t *cd) { + transcription_filter_data *gf_ = + static_cast(data_); + obs_log(gf_->log_level, "media_play"); + gf_->active = true; + }, + gf); + signal_handler_connect( + sh, "media_started", + [](void *data_, calldata_t *cd) { + transcription_filter_data *gf_ = + static_cast(data_); + obs_log(gf_->log_level, "media_started"); + gf_->active = true; + reset_caption_state(gf_); + }, + gf); + signal_handler_connect( + sh, "media_pause", + [](void *data_, calldata_t *cd) { + transcription_filter_data *gf_ = + static_cast(data_); + obs_log(gf_->log_level, "media_pause"); + gf_->active = false; + }, + gf); + signal_handler_connect( + sh, "media_restart", + [](void *data_, calldata_t *cd) { + transcription_filter_data *gf_ = + static_cast(data_); + obs_log(gf_->log_level, "media_restart"); + gf_->active = true; + reset_caption_state(gf_); + }, + gf); + signal_handler_connect( + sh, "media_stopped", + [](void *data_, calldata_t *cd) { + transcription_filter_data *gf_ = + static_cast(data_); + obs_log(gf_->log_level, "media_stopped"); + gf_->active = false; + reset_caption_state(gf_); + }, + gf); + gf->source_signals_set = true; +} + struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio) { if (!audio) { @@ -53,14 +133,16 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ struct transcription_filter_data *gf = static_cast(data); - if (!gf->active) { - return audio; + // Lazy initialization of source signals + if (!gf->source_signals_set) { + // obs_filter_get_parent only works in the filter function + obs_source_t *parent_source = obs_filter_get_parent(gf->context); + if (parent_source != nullptr) { + 
set_source_signals(gf, parent_source); + } } - // Check if the parent source is muted - obs_source_t *parent_source = obs_filter_get_parent(gf->context); - if (gf->process_while_muted == false && obs_source_muted(parent_source)) { - // Source is muted, do not process audio + if (!gf->active) { return audio; } @@ -69,6 +151,17 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ return audio; } + // Check if process while muted is not enabled (e.g. the user wants to avoid processing audio + // when the source is muted) + if (!gf->process_while_muted) { + // Check if the parent source is muted + obs_source_t *parent_source = obs_filter_get_parent(gf->context); + if (parent_source != nullptr && obs_source_muted(parent_source)) { + // Source is muted, do not process audio + return audio; + } + } + { std::lock_guard lock(gf->whisper_buf_mutex); // scoped lock // push back current audio data to input circlebuf @@ -117,191 +210,16 @@ void transcription_filter_destroy(void *data) bfree(gf); } -void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy, - struct transcription_filter_data *gf) -{ - auto target = obs_get_source_by_name(target_source_name.c_str()); - if (!target) { - obs_log(gf->log_level, "text_source target is null"); - return; - } - auto text_settings = obs_source_get_settings(target); - obs_data_set_string(text_settings, "text", str_copy.c_str()); - obs_source_update(target, text_settings); - obs_source_release(target); -} - -void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data, - size_t frames, int vad_state, const DetectionResultWithText &result) -{ - UNUSED_PARAMETER(gf); - UNUSED_PARAMETER(pcm32f_data); - UNUSED_PARAMETER(frames); - UNUSED_PARAMETER(vad_state); - UNUSED_PARAMETER(result); - // stub -} - -void set_text_callback(struct transcription_filter_data *gf, - const DetectionResultWithText &resultIn) -{ - DetectionResultWithText result = resultIn; - uint64_t now = now_ms(); - if (result.text.empty() || result.result != DETECTION_RESULT_SPEECH) { - // check if we should clear the current sub depending on the minimum subtitle duration - if ((now - gf->last_sub_render_time) > gf->min_sub_duration) { - // clear the current sub, run an empty sub - result.text = ""; - } else { - // nothing to do, the incoming sub is empty - return; - } - } - gf->last_sub_render_time = now; - - std::string str_copy = result.text; - - // recondition the text - only if the output is not English - if (gf->whisper_params.language != nullptr && - strcmp(gf->whisper_params.language, "en") != 0) { - str_copy = fix_utf8(str_copy); - } else { - // only remove leading and trailing non-alphanumeric characters if the output is English - str_copy = remove_leading_trailing_nonalpha(str_copy); - } - - // if suppression is enabled, check if the text is in the suppression list - if (!gf->suppress_sentences.empty()) { - // split the suppression list by newline into individual sentences - std::vector suppress_sentences_list = - split(gf->suppress_sentences, '\n'); - const std::string original_str_copy = str_copy; - // check if the text is in the suppression list - for (const std::string &suppress_sentence : suppress_sentences_list) { - // if suppress_sentence exists within str_copy, remove it (replace with "") - str_copy = std::regex_replace(str_copy, std::regex(suppress_sentence), ""); - } - // if the text was modified, log the original and modified text - if (original_str_copy != str_copy) { - 
obs_log(gf->log_level, "------ Suppressed text: '%s' -> '%s'", - original_str_copy.c_str(), str_copy.c_str()); - } - if (remove_leading_trailing_nonalpha(str_copy).empty()) { - // if the text is empty after suppression, return - return; - } - } - - if (gf->translate && !str_copy.empty() && str_copy != gf->last_text && - result.result == DETECTION_RESULT_SPEECH) { - obs_log(gf->log_level, "Translating text. %s -> %s", gf->source_lang.c_str(), - gf->target_lang.c_str()); - std::string translated_text; - if (translate(gf->translation_ctx, str_copy, gf->source_lang, gf->target_lang, - translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) { - if (gf->log_words) { - obs_log(LOG_INFO, "Translation: '%s' -> '%s'", str_copy.c_str(), - translated_text.c_str()); - } - if (gf->translation_output == "none") { - // overwrite the original text with the translated text - str_copy = translated_text; - } else { - // send the translation to the selected source - send_caption_to_source(gf->translation_output, translated_text, gf); - } - } else { - obs_log(gf->log_level, "Failed to translate text"); - } - } - - gf->last_text = str_copy; - - if (gf->buffered_output) { - gf->captions_monitor.addWords(result.tokens); - } - - if (gf->caption_to_stream) { - obs_output_t *streaming_output = obs_frontend_get_streaming_output(); - if (streaming_output) { - obs_output_output_caption_text1(streaming_output, str_copy.c_str()); - obs_output_release(streaming_output); - } - } - - if (gf->output_file_path != "" && gf->text_source_name.empty()) { - // Check if we should save the sentence - if (gf->save_only_while_recording && !obs_frontend_recording_active()) { - // We are not recording, do not save the sentence to file - return; - } - // should the file be truncated? - std::ios_base::openmode openmode = std::ios::out; - if (gf->truncate_output_file) { - openmode |= std::ios::trunc; - } else { - openmode |= std::ios::app; - } - if (!gf->save_srt) { - // Write raw sentence to file - std::ofstream output_file(gf->output_file_path, openmode); - output_file << str_copy << std::endl; - output_file.close(); - } else { - obs_log(gf->log_level, "Saving sentence to file %s, sentence #%d", - gf->output_file_path.c_str(), gf->sentence_number); - // Append sentence to file in .srt format - std::ofstream output_file(gf->output_file_path, openmode); - output_file << gf->sentence_number << std::endl; - // use the start and end timestamps to calculate the start and end time in srt format - auto format_ts_for_srt = [&output_file](uint64_t ts) { - uint64_t time_s = ts / 1000; - uint64_t time_m = time_s / 60; - uint64_t time_h = time_m / 60; - uint64_t time_ms_rem = ts % 1000; - uint64_t time_s_rem = time_s % 60; - uint64_t time_m_rem = time_m % 60; - uint64_t time_h_rem = time_h % 60; - output_file << std::setfill('0') << std::setw(2) << time_h_rem - << ":" << std::setfill('0') << std::setw(2) - << time_m_rem << ":" << std::setfill('0') - << std::setw(2) << time_s_rem << "," - << std::setfill('0') << std::setw(3) << time_ms_rem; - }; - format_ts_for_srt(result.start_timestamp_ms); - output_file << " --> "; - format_ts_for_srt(result.end_timestamp_ms); - output_file << std::endl; - - output_file << str_copy << std::endl; - output_file << std::endl; - output_file.close(); - gf->sentence_number++; - } - } else { - if (!gf->buffered_output) { - // Send the caption to the text source - send_caption_to_source(gf->text_source_name, str_copy, gf); - } - } -}; - void transcription_filter_update(void *data, obs_data_t *s) { + obs_log(LOG_INFO, 
"LocalVocal filter update"); struct transcription_filter_data *gf = static_cast(data); gf->log_level = (int)obs_data_get_int(s, "log_level"); - obs_log(gf->log_level, "filter update"); - gf->vad_enabled = obs_data_get_bool(s, "vad_enabled"); gf->log_words = obs_data_get_bool(s, "log_words"); - gf->frames = (size_t)((float)gf->sample_rate / - (1000.0f / (float)obs_data_get_int(s, "buffer_size_msec"))); gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream"); - bool step_by_step_processing = obs_data_get_bool(s, "step_by_step_processing"); - gf->step_size_msec = step_by_step_processing ? (int)obs_data_get_int(s, "step_size_msec") - : obs_data_get_int(s, "buffer_size_msec"); gf->save_srt = obs_data_get_bool(s, "subtitle_save_srt"); gf->truncate_output_file = obs_data_get_bool(s, "truncate_output_file"); gf->save_only_while_recording = obs_data_get_bool(s, "only_while_recording"); @@ -313,12 +231,32 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); gf->last_sub_render_time = 0; bool new_buffered_output = obs_data_get_bool(s, "buffered_output"); - if (new_buffered_output != gf->buffered_output) { - gf->buffered_output = new_buffered_output; - gf->overlap_ms = gf->buffered_output ? MAX_OVERLAP_SIZE_MSEC - : DEFAULT_OVERLAP_SIZE_MSEC; - gf->overlap_frames = - (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms)); + + if (new_buffered_output) { + obs_log(LOG_INFO, "buffered_output enable"); + if (!gf->buffered_output || !gf->captions_monitor.isEnabled()) { + obs_log(LOG_INFO, "buffered_output currently disabled, enabling"); + gf->buffered_output = true; + gf->captions_monitor.initialize( + gf, + [gf](const std::string &text) { + if (gf->buffered_output) { + send_caption_to_source(gf->text_source_name, text, + gf); + } + }, + 2, 30, std::chrono::seconds(10)); + } + } else { + obs_log(LOG_INFO, "buffered_output disable"); + if (gf->buffered_output) { + obs_log(LOG_INFO, "buffered_output currently enabled, disabling"); + if (gf->captions_monitor.isEnabled()) { + gf->captions_monitor.clear(); + gf->captions_monitor.stopThread(); + } + gf->buffered_output = false; + } } bool new_translate = obs_data_get_bool(s, "translate"); @@ -330,13 +268,15 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->translation_output = obs_data_get_string(s, "translate_output"); gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences"); std::string new_translate_model_index = obs_data_get_string(s, "translate_model"); - gf->translation_model_path_external = + std::string new_translation_model_path_external = obs_data_get_string(s, "translation_model_path_external"); if (new_translate != gf->translate || - new_translate_model_index != gf->translation_model_index) { + new_translate_model_index != gf->translation_model_index || + new_translation_model_path_external != gf->translation_model_path_external) { if (new_translate) { gf->translation_model_index = new_translate_model_index; + gf->translation_model_path_external = new_translation_model_path_external; if (gf->translation_model_index != "whisper-based-translation") { start_translation(gf); } else { @@ -349,10 +289,27 @@ void transcription_filter_update(void *data, obs_data_t *s) } } + // translation options + if (gf->translate) { + if (gf->translation_ctx.options) { + gf->translation_ctx.options->sampling_temperature = + (float)obs_data_get_double(s, "translation_sampling_temperature"); + gf->translation_ctx.options->repetition_penalty = 
+ (float)obs_data_get_double(s, "translation_repetition_penalty"); + gf->translation_ctx.options->beam_size = + (int)obs_data_get_int(s, "translation_beam_size"); + gf->translation_ctx.options->max_decoding_length = + (int)obs_data_get_int(s, "translation_max_decoding_length"); + gf->translation_ctx.options->no_repeat_ngram_size = + (int)obs_data_get_int(s, "translation_no_repeat_ngram_size"); + gf->translation_ctx.options->max_input_length = + (int)obs_data_get_int(s, "translation_max_input_length"); + } + } + obs_log(gf->log_level, "update text source"); // update the text source const char *new_text_source_name = obs_data_get_string(s, "subtitle_sources"); - obs_weak_source_t *old_weak_text_source = NULL; if (new_text_source_name == nullptr || strcmp(new_text_source_name, "none") == 0 || strcmp(new_text_source_name, "(null)") == 0 || @@ -369,16 +326,7 @@ void transcription_filter_update(void *data, obs_data_t *s) } } } else { - // new selected text source is valid, check if it's different from the old one - if (gf->text_source_name != new_text_source_name) { - // new text source is different from the old one, release the old one - gf->text_source_name = new_text_source_name; - } - } - - if (old_weak_text_source) { - obs_log(gf->log_level, "releasing old text source"); - obs_weak_source_release(old_weak_text_source); + gf->text_source_name = new_text_source_name; } obs_log(gf->log_level, "update whisper model"); @@ -436,10 +384,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->sample_rate = audio_output_get_sample_rate(obs_get_audio()); gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / MAX_MS_WORK_BUFFER)); gf->last_num_frames = 0; - bool step_by_step_processing = obs_data_get_bool(settings, "step_by_step_processing"); - gf->step_size_msec = step_by_step_processing - ? 
(int)obs_data_get_int(settings, "step_size_msec") - : obs_data_get_int(settings, "buffer_size_msec"); gf->min_sub_duration = (int)obs_data_get_int(settings, "min_sub_duration"); gf->last_sub_render_time = 0; gf->log_level = (int)obs_data_get_int(settings, "log_level"); @@ -467,8 +411,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->context = filter; - gf->overlap_ms = (int)obs_data_get_int(settings, "overlap_size_msec"); - gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms)); obs_log(gf->log_level, "channels %d, frames %d, sample_rate %d", (int)gf->channels, (int)gf->frames, gf->sample_rate); @@ -496,53 +438,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) obs_source_release(source); } else { // create a new OBS text source called "LocalVocal Subtitles" - obs_source_t *scene_as_source = obs_frontend_get_current_scene(); - obs_scene_t *scene = obs_scene_from_source(scene_as_source); -#ifdef _WIN32 - source = obs_source_create("text_gdiplus_v2", "LocalVocal Subtitles", - nullptr, nullptr); -#else - source = obs_source_create("text_ft2_source_v2", "LocalVocal Subtitles", - nullptr, nullptr); -#endif - if (source) { - // add source to the current scene - obs_scene_add(scene, source); - // set source settings - obs_data_t *source_settings = obs_source_get_settings(source); - obs_data_set_bool(source_settings, "word_wrap", true); - obs_data_set_int(source_settings, "custom_width", 1760); - obs_data_t *font_data = obs_data_create(); - obs_data_set_string(font_data, "face", "Arial"); - obs_data_set_string(font_data, "style", "Regular"); - obs_data_set_int(font_data, "size", 72); - obs_data_set_int(font_data, "flags", 0); - obs_data_set_obj(source_settings, "font", font_data); - obs_data_release(font_data); - obs_source_update(source, source_settings); - obs_data_release(source_settings); - - // set transform settings - obs_transform_info transform_info; - transform_info.pos.x = 962.0; - transform_info.pos.y = 959.0; - transform_info.bounds.x = 1769.0; - transform_info.bounds.y = 145.0; - transform_info.bounds_type = - obs_bounds_type::OBS_BOUNDS_SCALE_INNER; - transform_info.bounds_alignment = OBS_ALIGN_CENTER; - transform_info.alignment = OBS_ALIGN_CENTER; - transform_info.scale.x = 1.0; - transform_info.scale.y = 1.0; - transform_info.rot = 0.0; - obs_sceneitem_t *source_sceneitem = - obs_scene_sceneitem_from_source(scene, source); - obs_sceneitem_set_info(source_sceneitem, &transform_info); - obs_sceneitem_release(source_sceneitem); - - obs_source_release(source); - } - obs_source_release(scene_as_source); + create_obs_text_source(); } gf->text_source_name = "LocalVocal Subtitles"; obs_data_set_string(settings, "subtitle_sources", "LocalVocal Subtitles"); @@ -556,16 +452,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->whisper_model_path = std::string(""); // The update function will set the model path gf->whisper_context = nullptr; - gf->captions_monitor.initialize( - gf, - [gf](const std::string &text) { - obs_log(LOG_INFO, "Captions: %s", text.c_str()); - if (gf->buffered_output) { - send_caption_to_source(gf->text_source_name, text, gf); - } - }, - 30, std::chrono::seconds(10)); - obs_log(gf->log_level, "run update"); // get the settings updated on the filter data struct transcription_filter_update(gf, settings); @@ -574,45 +460,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) // handle the event 
OBS_FRONTEND_EVENT_RECORDING_STARTING to reset the srt sentence number // to match the subtitles with the recording - obs_frontend_add_event_callback( - [](enum obs_frontend_event event, void *private_data) { - if (event == OBS_FRONTEND_EVENT_RECORDING_STARTING) { - struct transcription_filter_data *gf_ = - static_cast( - private_data); - if (gf_->save_srt && gf_->save_only_while_recording) { - obs_log(gf_->log_level, - "Recording started. Resetting srt file."); - // truncate file if it exists - std::ofstream output_file(gf_->output_file_path, - std::ios::out | std::ios::trunc); - output_file.close(); - gf_->sentence_number = 1; - gf_->start_timestamp_ms = now_ms(); - } - } else if (event == OBS_FRONTEND_EVENT_RECORDING_STOPPED) { - struct transcription_filter_data *gf_ = - static_cast( - private_data); - if (gf_->save_srt && gf_->save_only_while_recording && - gf_->rename_file_to_match_recording) { - obs_log(gf_->log_level, - "Recording stopped. Rename srt file."); - // rename file to match the recording file name with .srt extension - // use obs_frontend_get_last_recording to get the last recording file name - std::string recording_file_name = - obs_frontend_get_last_recording(); - // remove the extension - recording_file_name = recording_file_name.substr( - 0, recording_file_name.find_last_of(".")); - std::string srt_file_name = recording_file_name + ".srt"; - // rename the file - std::rename(gf_->output_file_path.c_str(), - srt_file_name.c_str()); - } - } - }, - gf); + obs_frontend_add_event_callback(recording_state_callback, gf); obs_log(gf->log_level, "filter created."); return gf; @@ -661,15 +509,11 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_string(s, "whisper_model_path", "Whisper Tiny English (74Mb)"); obs_data_set_default_string(s, "whisper_language_select", "en"); obs_data_set_default_string(s, "subtitle_sources", "none"); - obs_data_set_default_bool(s, "step_by_step_processing", false); obs_data_set_default_bool(s, "process_while_muted", false); obs_data_set_default_bool(s, "subtitle_save_srt", false); obs_data_set_default_bool(s, "truncate_output_file", false); obs_data_set_default_bool(s, "only_while_recording", false); obs_data_set_default_bool(s, "rename_file_to_match_recording", true); - obs_data_set_default_int(s, "buffer_size_msec", DEFAULT_BUFFER_SIZE_MSEC); - obs_data_set_default_int(s, "overlap_size_msec", DEFAULT_OVERLAP_SIZE_MSEC); - obs_data_set_default_int(s, "step_size_msec", 1000); obs_data_set_default_int(s, "min_sub_duration", 3000); obs_data_set_default_bool(s, "advanced_settings", false); obs_data_set_default_bool(s, "translate", false); @@ -682,13 +526,21 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT); obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); + // translation options + obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); + obs_data_set_default_double(s, "translation_repetition_penalty", 2.0); + obs_data_set_default_int(s, "translation_beam_size", 1); + obs_data_set_default_int(s, "translation_max_decoding_length", 65); + obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); + obs_data_set_default_int(s, "translation_max_input_length", 65); + // Whisper parameters obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); obs_data_set_default_string(s, "initial_prompt", ""); obs_data_set_default_int(s, "n_threads", 4); obs_data_set_default_int(s, 
"n_max_text_ctx", 16384); obs_data_set_default_bool(s, "whisper_translate", false); - obs_data_set_default_bool(s, "no_context", false); + obs_data_set_default_bool(s, "no_context", true); obs_data_set_default_bool(s, "single_segment", true); obs_data_set_default_bool(s, "print_special", false); obs_data_set_default_bool(s, "print_progress", false); @@ -700,7 +552,7 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_double(s, "thold_ptsum", 0.01); obs_data_set_default_int(s, "max_len", 0); obs_data_set_default_bool(s, "split_on_word", true); - obs_data_set_default_int(s, "max_tokens", 32); + obs_data_set_default_int(s, "max_tokens", 0); obs_data_set_default_bool(s, "speed_up", false); obs_data_set_default_bool(s, "suppress_blank", false); obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); @@ -827,8 +679,18 @@ obs_properties_t *transcription_filter_properties(void *data) // input const char *new_model_path = obs_data_get_string(settings, "translate_model"); const bool is_external = (strcmp(new_model_path, "!!!external!!!") == 0); + const bool is_whisper = (strcmp(new_model_path, "whisper-based-translation") == 0); obs_property_set_visible( obs_properties_get(props, "translation_model_path_external"), is_external); + obs_property_set_visible(obs_properties_get(props, "translate_source_language"), + !is_whisper); + obs_property_set_visible(obs_properties_get(props, "translate_add_context"), + !is_whisper); + obs_property_set_visible(obs_properties_get(props, + "translate_input_tokenization_style"), + !is_whisper); + obs_property_set_visible(obs_properties_get(props, "translate_output"), + !is_whisper); return true; }); // add target language selection @@ -865,9 +727,13 @@ obs_properties_t *transcription_filter_properties(void *data) UNUSED_PARAMETER(property); // Show/Hide the translation group const bool translate_enabled = obs_data_get_bool(settings, "translate"); - for (const auto &prop : {"translate_target_language", "translate_source_language", - "translate_add_context", "translate_output", - "translate_model", "token_style"}) { + for (const auto &prop : + {"translate_target_language", "translate_source_language", + "translate_add_context", "translate_output", "translate_model", + "translate_input_tokenization_style", "translation_sampling_temperature", + "translation_repetition_penalty", "translation_beam_size", + "translation_max_decoding_length", "translation_no_repeat_ngram_size", + "translation_max_input_length"}) { obs_property_set_visible(obs_properties_get(props, prop), translate_enabled); } @@ -886,6 +752,20 @@ obs_properties_t *transcription_filter_properties(void *data) obs_property_list_add_int(prop_token_style, "M2M100 Tokens", INPUT_TOKENIZAION_M2M100); obs_property_list_add_int(prop_token_style, "T5 Tokens", INPUT_TOKENIZAION_T5); + // add translation options: beam_size, max_decoding_length, repetition_penalty, no_repeat_ngram_size, max_input_length, sampling_temperature + obs_properties_add_float_slider(translation_group, "translation_sampling_temperature", + MT_("translation_sampling_temperature"), 0.0, 1.0, 0.05); + obs_properties_add_float_slider(translation_group, "translation_repetition_penalty", + MT_("translation_repetition_penalty"), 1.0, 5.0, 0.25); + obs_properties_add_int_slider(translation_group, "translation_beam_size", + MT_("translation_beam_size"), 1, 10, 1); + obs_properties_add_int_slider(translation_group, "translation_max_decoding_length", + MT_("translation_max_decoding_length"), 1, 100, 5); + 
obs_properties_add_int_slider(translation_group, "translation_max_input_length",
+				      MT_("translation_max_input_length"), 1, 100, 5);
+	obs_properties_add_int_slider(translation_group, "translation_no_repeat_ngram_size",
+				      MT_("translation_no_repeat_ngram_size"), 1, 10, 1);
+
 	obs_property_t *advanced_settings_prop =
 		obs_properties_add_bool(ppts, "advanced_settings", MT_("advanced_settings"));
 	obs_property_set_modified_callback(advanced_settings_prop, [](obs_properties_t *props,
@@ -898,7 +778,7 @@
 		{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
 		 "overlap_size_msec", "step_by_step_processing", "min_sub_duration",
 		 "process_while_muted", "buffered_output", "vad_enabled", "log_level",
-		 "suppress_sentences", "sentence_psum_accept_thresh"}) {
+		 "suppress_sentences", "sentence_psum_accept_thresh", "send_timed_metadata"}) {
 			obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
 						 show_hide);
 		}
@@ -907,48 +787,16 @@
 	obs_property_t *buffered_output_prop =
 		obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output"));
 
-	// add on-change handler for buffered_output
-	obs_property_set_modified_callback(buffered_output_prop, [](obs_properties_t *props,
-								     obs_property_t *property,
-								     obs_data_t *settings) {
-		UNUSED_PARAMETER(property);
-		UNUSED_PARAMETER(props);
-		// if buffered output is enabled set the overlap to max else set it to default
-		obs_data_set_int(settings, "overlap_size_msec",
-				 obs_data_get_bool(settings, "buffered_output")
-					 ? MAX_OVERLAP_SIZE_MSEC
-					 : DEFAULT_OVERLAP_SIZE_MSEC);
-		return true;
-	});
 	obs_properties_add_bool(ppts, "log_words", MT_("log_words"));
 	obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream"));
+	obs_properties_add_bool(ppts, "send_timed_metadata", MT_("send_timed_metadata"));
 
-	obs_properties_add_int_slider(ppts, "buffer_size_msec", MT_("buffer_size_msec"), 1000,
-				      DEFAULT_BUFFER_SIZE_MSEC, 250);
-	obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"),
-				      MIN_OVERLAP_SIZE_MSEC, MAX_OVERLAP_SIZE_MSEC,
-				      (MAX_OVERLAP_SIZE_MSEC - MIN_OVERLAP_SIZE_MSEC) / 5);
-
-	obs_property_t *step_by_step_processing = obs_properties_add_bool(
-		ppts, "step_by_step_processing", MT_("step_by_step_processing"));
-	obs_properties_add_int_slider(ppts, "step_size_msec", MT_("step_size_msec"), 1000,
-				      DEFAULT_BUFFER_SIZE_MSEC, 50);
 	obs_properties_add_int_slider(ppts, "min_sub_duration", MT_("min_sub_duration"), 1000,
 				      5000, 50);
 	obs_properties_add_float_slider(ppts, "sentence_psum_accept_thresh",
 					MT_("sentence_psum_accept_thresh"), 0.0, 1.0, 0.05);
 
-	obs_property_set_modified_callback(step_by_step_processing, [](obs_properties_t *props,
-								       obs_property_t *property,
-								       obs_data_t *settings) {
-		UNUSED_PARAMETER(property);
-		// Show/Hide the step size input
-		obs_property_set_visible(obs_properties_get(props, "step_size_msec"),
-					 obs_data_get_bool(settings, "step_by_step_processing"));
-		return true;
-	});
-
 	obs_properties_add_bool(ppts, "process_while_muted", MT_("process_while_muted"));
 
 	obs_properties_add_bool(ppts, "vad_enabled", MT_("vad_enabled"));
diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp
index ca9e0f1..415b47b 100644
--- a/src/transcription-utils.cpp
+++ b/src/transcription-utils.cpp
@@ -117,3 +117,23 @@ std::vector<std::string> split(const std::string &string, char delimiter)
 	}
 	return tokens;
 }
+
+std::vector<std::string> split_words(const std::string &str_copy)
+{
+	std::vector<std::string> words;
+	std::string word;
+	for (char c : str_copy) {
+		if (std::isspace(c)) {
+			if (!word.empty()) {
+				words.push_back(word);
+				word.clear();
+			}
+		} else {
+			word += c;
+		}
+	}
+	if (!word.empty()) {
+		words.push_back(word);
+	}
+	return words;
+}
diff --git a/src/transcription-utils.h b/src/transcription-utils.h
index e5eb274..4e7f39c 100644
--- a/src/transcription-utils.h
+++ b/src/transcription-utils.h
@@ -4,36 +4,17 @@
 #include <chrono>
 #include <string>
 #include <vector>
-#include <media-io/audio-io.h>
 
+// Fix UTF8 string for Windows
 std::string fix_utf8(const std::string &str);
+
+// Remove leading and trailing non-alphabetic characters
 std::string remove_leading_trailing_nonalpha(const std::string &str);
-std::vector<std::string> split(const std::string &string, char delimiter);
 
-inline enum speaker_layout convert_speaker_layout(uint8_t channels)
-{
-	switch (channels) {
-	case 0:
-		return SPEAKERS_UNKNOWN;
-	case 1:
-		return SPEAKERS_MONO;
-	case 2:
-		return SPEAKERS_STEREO;
-	case 3:
-		return SPEAKERS_2POINT1;
-	case 4:
-		return SPEAKERS_4POINT0;
-	case 5:
-		return SPEAKERS_4POINT1;
-	case 6:
-		return SPEAKERS_5POINT1;
-	case 8:
-		return SPEAKERS_7POINT1;
-	default:
-		return SPEAKERS_UNKNOWN;
-	}
-}
+// Split a string by a delimiter
+std::vector<std::string> split(const std::string &string, char delimiter);
 
+// Get the current timestamp in milliseconds since epoch
 inline uint64_t now_ms()
 {
 	return std::chrono::duration_cast<std::chrono::milliseconds>(
@@ -41,4 +22,7 @@ inline uint64_t now_ms()
 		.count();
 }
 
+// Split a string into words based on spaces
+std::vector<std::string> split_words(const std::string &str_copy);
+
 #endif // TRANSCRIPTION_UTILS_H
diff --git a/src/translation/translation-includes.h b/src/translation/translation-includes.h
new file mode 100644
index 0000000..6520389
--- /dev/null
+++ b/src/translation/translation-includes.h
@@ -0,0 +1,8 @@
+#ifndef TRANSLATION_INCLUDES_H
+#define TRANSLATION_INCLUDES_H
+
+#include <ctranslate2/translator.h>
+#include <sentencepiece_processor.h>
+#include
+
+#endif // TRANSLATION_INCLUDES_H
diff --git a/src/translation/translation-utils.cpp b/src/translation/translation-utils.cpp
index 439a783..07ca268 100644
--- a/src/translation/translation-utils.cpp
+++ b/src/translation/translation-utils.cpp
@@ -1,12 +1,11 @@
+#include <obs-module.h>
 
-#include "translation-utils.h"
-
+#include "translation-includes.h"
 #include "translation.h"
+#include "translation-utils.h"
 #include "plugin-support.h"
 #include "model-utils/model-downloader.h"
 
-#include <obs-module.h>
-
 void start_translation(struct transcription_filter_data *gf)
 {
 	obs_log(LOG_INFO, "Starting translation...");
diff --git a/src/translation/translation-utils.h b/src/translation/translation-utils.h
index 1305ff0..8a06ab4 100644
--- a/src/translation/translation-utils.h
+++ b/src/translation/translation-utils.h
@@ -1,4 +1,8 @@
+#ifndef TRANSLATION_UTILS_H
+#define TRANSLATION_UTILS_H
 
 #include "transcription-filter-data.h"
 
 void start_translation(struct transcription_filter_data *gf);
+
+#endif // TRANSLATION_UTILS_H
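The translation.cpp hunk just below retunes the CTranslate2 decoding defaults that the new translation_* sliders above expose: a greedy beam of 1, a longer max_decoding_length (40 -> 64), a stronger repetition_penalty, and a no_repeat_ngram_size of 1. A minimal sketch of how these options would drive an actual decode, assuming a constructed ctranslate2::Translator and SentencePiece-tokenized input; the helper name translate_once is illustrative and not part of the patch:

    #include <ctranslate2/translator.h>

    #include <string>
    #include <vector>

    // Hypothetical helper: one greedy decode with the defaults this patch sets.
    std::vector<std::string> translate_once(ctranslate2::Translator &translator,
                                            const std::vector<std::string> &input_tokens)
    {
        ctranslate2::TranslationOptions options;
        options.beam_size = 1;             // greedy decode, lowest latency
        options.max_decoding_length = 64;  // raised from 40 by this patch
        options.repetition_penalty = 2.0f; // penalize repeated tokens harder
        options.no_repeat_ngram_size = 1;
        options.max_input_length = 64;
        options.sampling_temperature = 0.1f;
        // translate_batch takes a batch of token vectors; pass a batch of one
        const auto results = translator.translate_batch({input_tokens}, options);
        return results[0].output(); // translated target-language tokens
    }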
translation_ctx.options->max_decoding_length = 64; + translation_ctx.options->repetition_penalty = 2.0f; + translation_ctx.options->no_repeat_ngram_size = 1; + translation_ctx.options->max_input_length = 64; + translation_ctx.options->sampling_temperature = 0.1f; } catch (std::exception &e) { obs_log(LOG_ERROR, "Failed to load CT2 model: %s", e.what()); return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; diff --git a/src/translation/translation.h b/src/translation/translation.h index bf99f42..1b601fc 100644 --- a/src/translation/translation.h +++ b/src/translation/translation.h @@ -1,13 +1,22 @@ -#pragma once +#ifndef TRANSLATION_H +#define TRANSLATION_H -#include -#include #include #include #include +#include enum InputTokenizationStyle { INPUT_TOKENIZAION_M2M100 = 0, INPUT_TOKENIZAION_T5 }; +namespace ctranslate2 { +class Translator; +class TranslationOptions; +} // namespace ctranslate2 + +namespace sentencepiece { +class SentencePieceProcessor; +} // namespace sentencepiece + struct translation_context { std::string local_model_folder_path; std::unique_ptr processor; @@ -33,3 +42,5 @@ int translate(struct translation_context &translation_ctx, const std::string &te #define OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS 0 #define OBS_POLYGLOT_TRANSLATION_SUCCESS 0 #define OBS_POLYGLOT_TRANSLATION_FAIL -1 + +#endif // TRANSLATION_H diff --git a/src/utils.cpp b/src/utils.cpp deleted file mode 100644 index 6639ae7..0000000 --- a/src/utils.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "utils.h" - -std::vector split_words(const std::string &str_copy) -{ - std::vector words; - std::string word; - for (char c : str_copy) { - if (std::isspace(c)) { - if (!word.empty()) { - words.push_back(word); - word.clear(); - } - } else { - word += c; - } - } - if (!word.empty()) { - words.push_back(word); - } - return words; -} diff --git a/src/utils.h b/src/utils.h deleted file mode 100644 index 9348417..0000000 --- a/src/utils.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef UTILS_H -#define UTILS_H - -#include -#include - -std::vector split_words(const std::string &str_copy); - -#endif // UTILS_H diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index 13d2ffc..02bddb3 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -1,131 +1,215 @@ +#include +#include +#include + #include "token-buffer-thread.h" -#include "./whisper-utils.h" +#include "whisper-utils.h" + +#include + +#ifdef _WIN32 +#include +#define SPACE L" " +#define NEWLINE L"\n" +#else +#define SPACE " " +#define NEWLINE "\n" +#endif TokenBufferThread::~TokenBufferThread() { { - std::lock_guard lock(queueMutex); + std::lock_guard lock(inputQueueMutex); stop = true; } condVar.notify_all(); - workerThread.join(); + if (workerThread.joinable()) { + workerThread.join(); + } } void TokenBufferThread::initialize(struct transcription_filter_data *gf_, std::function callback_, - size_t maxSize_, std::chrono::seconds maxTime_) + size_t numSentences_, size_t numPerSentence_, + std::chrono::seconds maxTime_, + TokenBufferSegmentation segmentation_) { this->gf = gf_; this->callback = callback_; - this->maxSize = maxSize_; + this->numSentences = numSentences_; + this->numPerSentence = numPerSentence_; + this->segmentation = segmentation_; this->maxTime = maxTime_; - this->initialized = true; + this->stop = false; this->workerThread = std::thread(&TokenBufferThread::monitor, this); } -void TokenBufferThread::log_token_vector(const std::vector &tokens) +void 
TokenBufferThread::stopThread() +{ + std::lock_guard lock(inputQueueMutex); + stop = true; + condVar.notify_all(); + if (workerThread.joinable()) { + workerThread.join(); + } +} + +void TokenBufferThread::log_token_vector(const std::vector &tokens) { std::string output; for (const auto &token : tokens) { - const char *token_str = whisper_token_to_str(gf->whisper_context, token.id); - output += token_str; + output += token; } obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); } -void TokenBufferThread::addWords(const std::vector &words) +void TokenBufferThread::addSentence(const std::string &sentence) { - obs_log(LOG_INFO, "TokenBufferThread::addWords"); - { - std::lock_guard lock(queueMutex); - - // convert current wordQueue to vector - std::vector currentWords(wordQueue.begin(), wordQueue.end()); - - log_token_vector(currentWords); - log_token_vector(words); - - // run reconstructSentence - std::vector reconstructed = - reconstructSentence(currentWords, words); +#ifdef _WIN32 + // on windows convert from multibyte to wide char + int count = + MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0); + std::wstring sentence_ws(count, 0); + MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0], + count); + // split to characters + std::vector characters; + for (const auto &c : sentence_ws) { + characters.push_back(std::wstring(1, c)); + } +#else + // split to characters + std::vector characters; + for (const auto &c : sentence_ws) { + characters.push_back(std::string(1, c)); + } +#endif - log_token_vector(reconstructed); + std::lock_guard lock(inputQueueMutex); - // clear the wordQueue - wordQueue.clear(); + // if the inputqueue and sentence don't have a space in them, add a space + if (!inputQueue.empty() && !sentence.empty() && inputQueue.back() != SPACE && + characters.front() != SPACE) { + inputQueue.push_back(SPACE); + } - // add the reconstructed sentence to the wordQueue - for (const auto &word : reconstructed) { - wordQueue.push_back(word); - } + // add the reconstructed sentence to the wordQueue + for (const auto &character : characters) { + inputQueue.push_back(character); + } +} - newDataAvailable = true; +void TokenBufferThread::clear() +{ + { + std::lock_guard lock(inputQueueMutex); + inputQueue.clear(); } - condVar.notify_all(); + { + std::lock_guard lock(presentationQueueMutex); + presentationQueue.clear(); + } + this->callback(""); } void TokenBufferThread::monitor() { obs_log(LOG_INFO, "TokenBufferThread::monitor"); - auto startTime = std::chrono::steady_clock::now(); - while (this->initialized && !this->stop) { - std::unique_lock lock(this->queueMutex); - // wait for new data or stop signal - this->condVar.wait(lock, [this] { return this->newDataAvailable || this->stop; }); - - if (this->stop) { - break; - } - - if (this->wordQueue.empty()) { - continue; - } - if (this->gf->whisper_context == nullptr) { - continue; - } - - // emit up to maxSize words from the wordQueue - std::vector emitted; - while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) { - emitted.push_back(this->wordQueue.front()); - this->wordQueue.pop_front(); - } - obs_log(LOG_INFO, "TokenBufferThread::monitor: emitting %d words", emitted.size()); - log_token_vector(emitted); - // emit the caption from the tokens - std::string output; - for (const auto &token : emitted) { - const char *token_str = - whisper_token_to_str(this->gf->whisper_context, token.id); - output += token_str; - } - 
this->callback(output);
-			// push back the words that were emitted, in reverse order
-			for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
-				this->wordQueue.push_front(*it);
-			}
+	this->callback("");
+
+	while (!this->stop) {
+		{
+			std::unique_lock<std::mutex> lockPresentation(this->presentationQueueMutex);
+			// condition presentation queue
+			if (presentationQueue.size() == this->numSentences * this->numPerSentence) {
+				// pop a whole sentence from the presentation queue front
+				for (size_t i = 0; i < this->numPerSentence; i++) {
+					presentationQueue.pop_front();
+				}
+			}
 
-		// check if we need to flush the queue
-		auto elapsedTime = std::chrono::duration_cast(
-			std::chrono::steady_clock::now() - startTime);
-		if (this->wordQueue.size() >= this->maxSize || elapsedTime >= this->maxTime) {
-			// flush the queue if it's full or we've reached the max time
-			size_t words_to_flush = std::min(this->wordQueue.size(), this->maxSize);
-			// make sure we leave at least 3 words in the queue
-			size_t words_remaining = this->wordQueue.size() - words_to_flush;
-			if (words_remaining < 3) {
-				words_to_flush -= 3 - words_remaining;
+			{
+				std::unique_lock<std::mutex> lock(this->inputQueueMutex);
+
+				if (!inputQueue.empty()) {
+					// if there are tokens on the input queue
+					if (this->segmentation == SEGMENTATION_SENTENCE) {
+						// add all the tokens from the input queue to the presentation queue
+						for (const auto &token : inputQueue) {
+							presentationQueue.push_back(token);
+						}
+					} else {
+						// add one token to the presentation queue
+						presentationQueue.push_back(inputQueue.front());
+						inputQueue.pop_front();
+					}
+				}
 			}
-			obs_log(LOG_INFO, "TokenBufferThread::monitor: flushing %d words",
-				words_to_flush);
-			for (size_t i = 0; i < words_to_flush; ++i) {
-				wordQueue.pop_front();
+
+			if (presentationQueue.size() > 0) {
+				// build a caption from the presentation queue in sentences
+				// with a maximum of numPerSentence tokens/words per sentence
+				// and a newline between sentences
+				TokenBufferString caption;
+				if (this->segmentation == SEGMENTATION_WORD) {
+					// iterate through the presentation queue tokens and make words (based on spaces)
+					// then build a caption with a maximum of numPerSentence words per sentence
+					size_t wordsInSentence = 0;
+					TokenBufferString word;
+					for (const auto &token : presentationQueue) {
+						// keep adding tokens to the word until a space is found
+						word += token;
+						if (word.find(SPACE) != TokenBufferString::npos) {
+							// cut the word at the space and add it to the caption
+							caption += word.substr(0, word.find(SPACE));
+							wordsInSentence++;
+							// keep the rest of the word for the next iteration
+							word = word.substr(word.find(SPACE) + 1);
+
+							if (wordsInSentence ==
+							    this->numPerSentence) {
+								caption += word;
+								caption += SPACE;
+								wordsInSentence = 0;
+								word.clear();
+							}
+						}
+					}
+				} else {
+					// iterate through the presentation queue tokens and build a caption
+					size_t tokensInSentence = 0;
+					for (const auto &token : presentationQueue) {
+						caption += token;
+						tokensInSentence++;
+						if (tokensInSentence == this->numPerSentence) {
+							caption += NEWLINE;
+							tokensInSentence = 0;
+						}
+					}
+				}
+
+#ifdef _WIN32
+				// convert caption to multibyte for obs
+				int count = WideCharToMultiByte(CP_UTF8, 0, caption.c_str(),
+								(int)caption.length(), NULL, 0,
+								NULL, NULL);
+				std::string caption_out(count, 0);
+				WideCharToMultiByte(CP_UTF8, 0, caption.c_str(),
+						    (int)caption.length(), &caption_out[0], count,
+						    NULL, NULL);
+#else
+				std::string caption_out(caption.begin(), caption.end());
+#endif
+
+				// emit the caption
+				
this->callback(caption_out); } - startTime = std::chrono::steady_clock::now(); } - newDataAvailable = false; + // sleep for 100 ms + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } + obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); } diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 1a56b70..8d73285 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -12,12 +12,18 @@ #include -#include - #include "plugin-support.h" +#ifdef _WIN32 +typedef std::wstring TokenBufferString; +#else +typedef std::string TokenBufferString; +#endif + struct transcription_filter_data; +enum TokenBufferSegmentation { SEGMENTATION_WORD = 0, SEGMENTATION_TOKEN, SEGMENTATION_SENTENCE }; + class TokenBufferThread { public: // default constructor @@ -25,25 +31,33 @@ class TokenBufferThread { ~TokenBufferThread(); void initialize(struct transcription_filter_data *gf, - std::function callback_, size_t maxSize_, - std::chrono::seconds maxTime_); + std::function callback_, size_t numSentences_, + size_t numTokensPerSentence_, std::chrono::seconds maxTime_, + TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN); + + void addSentence(const std::string &sentence); + void clear(); + void stopThread(); - void addWords(const std::vector &words); + bool isEnabled() const { return !stop; } private: void monitor(); - void log_token_vector(const std::vector &tokens); + void log_token_vector(const std::vector &tokens); struct transcription_filter_data *gf; - std::deque wordQueue; + std::deque inputQueue; + std::deque presentationQueue; std::thread workerThread; - std::mutex queueMutex; + std::mutex inputQueueMutex; + std::mutex presentationQueueMutex; std::condition_variable condVar; std::function callback; - size_t maxSize; std::chrono::seconds maxTime; - bool stop; - bool initialized = false; + bool stop = true; bool newDataAvailable = false; + size_t numSentences; + size_t numPerSentence; + TokenBufferSegmentation segmentation; }; #endif diff --git a/src/whisper-utils/whisper-model-utils.cpp b/src/whisper-utils/whisper-model-utils.cpp index cc27484..c9620c8 100644 --- a/src/whisper-utils/whisper-model-utils.cpp +++ b/src/whisper-utils/whisper-model-utils.cpp @@ -1,9 +1,13 @@ +#ifdef _WIN32 +#define NOMINMAX +#endif + +#include + #include "whisper-utils.h" +#include "whisper-processing.h" #include "plugin-support.h" #include "model-utils/model-downloader.h" -#include "whisper-processing.h" - -#include void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) { @@ -98,9 +102,5 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps"); shutdown_whisper_thread(gf); start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file); - } else { - // dtw_token_timestamps did not change - obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d", - gf->enable_token_ts_dtw, new_dtw_timestamps); } } diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 919eb3a..9d01762 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -18,6 +18,7 @@ #include #include +#include struct vad_state { bool vad_on; @@ -275,192 +276,14 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter return {DETECTION_RESULT_SILENCE, "", t0, t1, {}}; } - return 
{DETECTION_RESULT_SPEECH, text, t0, t1, tokens}; - } -} - -void process_audio_from_buffer(struct transcription_filter_data *gf) -{ - uint32_t num_new_frames_from_infos = 0; - uint64_t start_timestamp = 0; - bool save_overlap_region = true; - - { - // scoped lock the buffer mutex - std::lock_guard lock(gf->whisper_buf_mutex); - - // We need (gf->frames - gf->last_num_frames) new frames for a full segment, - const size_t remaining_frames_to_full_segment = gf->frames - gf->last_num_frames; - - obs_log(gf->log_level, - "processing audio from buffer, %lu existing frames, %lu frames needed to full segment (%d frames)", - gf->last_num_frames, remaining_frames_to_full_segment, gf->frames); - - // pop infos from the info buffer and mark the beginning timestamp from the first - // info as the beginning timestamp of the segment - struct transcription_filter_audio_info info_from_buf = {0}; - const size_t size_of_audio_info = sizeof(struct transcription_filter_audio_info); - while (gf->info_buffer.size >= size_of_audio_info) { - circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); - num_new_frames_from_infos += info_from_buf.frames; - if (start_timestamp == 0) { - start_timestamp = info_from_buf.timestamp; - } - // Check if we're within the needed segment length - if (num_new_frames_from_infos > remaining_frames_to_full_segment) { - // too big, push the last info into the buffer's front where it was - num_new_frames_from_infos -= info_from_buf.frames; - circlebuf_push_front(&gf->info_buffer, &info_from_buf, - size_of_audio_info); - break; - } - } - - obs_log(gf->log_level, - "with %lu remaining to full segment, popped %d frames from info buffer, pushed at %lu (overlap)", - remaining_frames_to_full_segment, num_new_frames_from_infos, - gf->last_num_frames); - - /* Pop from input circlebuf */ - for (size_t c = 0; c < gf->channels; c++) { - // Push the new data to the end of the existing buffer copy_buffers[c] - circlebuf_pop_front(&gf->input_buffers[c], - gf->copy_buffers[c] + gf->last_num_frames, - num_new_frames_from_infos * sizeof(float)); - } - } - - if (gf->last_num_frames > 0) { - obs_log(gf->log_level, "full segment, %lu frames overlap, %lu frames to process", - gf->last_num_frames, gf->last_num_frames + num_new_frames_from_infos); - gf->last_num_frames += num_new_frames_from_infos; - } else { - gf->last_num_frames = num_new_frames_from_infos; - obs_log(gf->log_level, "first segment, no overlap exists, %lu frames to process", - gf->last_num_frames); - } - - obs_log(gf->log_level, "processing %lu frames (%d ms), start timestamp %llu", - gf->last_num_frames, - (int)((float)gf->last_num_frames * 1000.0f / (float)gf->sample_rate), - start_timestamp); - - // time the audio processing - auto start = std::chrono::high_resolution_clock::now(); - - // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - audio_resampler_resample(gf->resampler_to_whisper, (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, (uint32_t)gf->last_num_frames); - - obs_log(gf->log_level, "%d channels, %d frames, %f ms", (int)gf->channels, - (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f); - - bool skipped_inference = false; - uint32_t speech_start_frame = 0; - uint32_t speech_end_frame = resampled_16khz_frames; - - if (gf->vad_enabled) { - std::vector vad_input(resampled_16khz[0], - resampled_16khz[0] + resampled_16khz_frames); - 
gf->vad->process(vad_input, false); - - std::vector stamps = gf->vad->get_speech_timestamps(); - if (stamps.size() == 0) { - obs_log(gf->log_level, "VAD detected no speech in %d frames", - resampled_16khz_frames); - skipped_inference = true; - // prevent copying the buffer to the beginning (overlap) - save_overlap_region = false; - } else { - // if the vad finds that start within the first 10% of the buffer, set the start to 0 - speech_start_frame = (stamps[0].start < (int)(resampled_16khz_frames / 10)) - ? 0 - : stamps[0].start; - speech_end_frame = stamps.back().end; - uint32_t number_of_frames = speech_end_frame - speech_start_frame; - - // if the speech is pressed up against the end of the buffer - // apply the overlapped region, else don't - save_overlap_region = (speech_end_frame == resampled_16khz_frames); - - obs_log(gf->log_level, - "VAD detected speech from %d to %d (%d frames, %d ms)", - speech_start_frame, speech_end_frame, number_of_frames, - number_of_frames * 1000 / WHISPER_SAMPLE_RATE); - - // if the speech is less than 1 second - pad with zeros and send for inference - if (number_of_frames > 0 && number_of_frames < WHISPER_SAMPLE_RATE) { - obs_log(gf->log_level, - "Speech segment is less than 1 second, padding with zeros to 1 second"); - // copy the speech segment to the beginning of the resampled buffer - // use memmove to copy the speech segment to the beginning of the buffer - memmove(resampled_16khz[0], resampled_16khz[0] + speech_start_frame, - number_of_frames * sizeof(float)); - // zero out the rest of the buffer - memset(resampled_16khz[0] + number_of_frames, 0, - (WHISPER_SAMPLE_RATE - number_of_frames) * sizeof(float)); - - speech_start_frame = 0; - speech_end_frame = WHISPER_SAMPLE_RATE; - } - } - } - - if (!skipped_inference) { - // run inference - const struct DetectionResultWithText inference_result = - run_whisper_inference(gf, resampled_16khz[0] + speech_start_frame, - speech_end_frame - speech_start_frame); - - if (inference_result.result == DETECTION_RESULT_SPEECH) { - // output inference result to a text source - set_text_callback(gf, inference_result); - } else if (inference_result.result == DETECTION_RESULT_SILENCE) { - // output inference result to a text source - set_text_callback(gf, {inference_result.result, "[silence]", 0, 0, {}}); - } - } else { - if (gf->log_words) { - obs_log(LOG_INFO, "skipping inference"); + // Check regex for "MBC .*" to detect false prediction + std::regex mbc_regex("MBC.*"); + if (std::regex_match(text, mbc_regex)) { + obs_log(gf->log_level, "False prediction detected: %s", text.c_str()); + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}}; } - set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0, {}}); - } - - // end of timer - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start).count(); - const uint64_t last_num_frames_ms = gf->last_num_frames * 1000 / gf->sample_rate; - obs_log(gf->log_level, "audio processing of %lu ms data took %d ms", last_num_frames_ms, - (int)duration); - if (save_overlap_region) { - const uint64_t overlap_size_ms = - (uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate); - obs_log(gf->log_level, - "copying %lu overlap frames (%lu ms) from the end of the buffer (pos %lu) to the beginning", - gf->overlap_frames, overlap_size_ms, - gf->last_num_frames - gf->overlap_frames); - for (size_t c = 0; c < gf->channels; c++) { - // zero out the copy buffer, just in case - memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); - // 
move overlap frames from the end of the last copy_buffers to the beginning - memmove(gf->copy_buffers[c], - gf->copy_buffers[c] + gf->last_num_frames - gf->overlap_frames, - gf->overlap_frames * sizeof(float)); - } - gf->last_num_frames = gf->overlap_frames; - } else { - obs_log(gf->log_level, "no overlap needed. zeroing out the copy buffer"); - // zero out the copy buffer, just in case - for (size_t c = 0; c < gf->channels; c++) { - memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); - } - gf->last_num_frames = 0; + return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens}; } } @@ -468,18 +291,22 @@ void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_o uint64_t end_offset_ms, int vad_state) { // get the data from the entire whisper buffer + // add 50ms of silence to the beginning and end of the buffer const size_t pcm32f_size = gf->whisper_buffer.size / sizeof(float); + const size_t pcm32f_size_with_silence = pcm32f_size + 2 * WHISPER_SAMPLE_RATE / 100; // allocate a new buffer and copy the data to it - float *pcm32f_data = (float *)bzalloc(pcm32f_size * sizeof(float)); - circlebuf_pop_back(&gf->whisper_buffer, pcm32f_data, pcm32f_size * sizeof(float)); + float *pcm32f_data = (float *)bzalloc(pcm32f_size_with_silence * sizeof(float)); + circlebuf_pop_back(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, + pcm32f_size * sizeof(float)); - struct DetectionResultWithText inference_result = - run_whisper_inference(gf, pcm32f_data, pcm32f_size, start_offset_ms, end_offset_ms); + struct DetectionResultWithText inference_result = run_whisper_inference( + gf, pcm32f_data, pcm32f_size_with_silence, start_offset_ms, end_offset_ms); // output inference result to a text source set_text_callback(gf, inference_result); if (gf->enable_audio_chunks_callback) { - audio_chunk_callback(gf, pcm32f_data, pcm32f_size, vad_state, inference_result); + audio_chunk_callback(gf, pcm32f_data, pcm32f_size_with_silence, vad_state, + inference_result); } // free the buffer From 4ec47d264c3721f3ad37061e20510ff43fa4c19a Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 3 Jun 2024 17:24:21 -0400 Subject: [PATCH 02/12] refactor: Add buffered output parameters for transcription filter --- data/locale/en-US.ini | 3 ++ src/transcription-filter-callbacks.cpp | 4 --- src/transcription-filter-data.h | 4 ++- src/transcription-filter.cpp | 39 ++++++++++++++++++++++- src/whisper-utils/token-buffer-thread.cpp | 17 +++++----- src/whisper-utils/token-buffer-thread.h | 3 ++ 6 files changed, 56 insertions(+), 14 deletions(-) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index c9b8757..5507446 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -67,3 +67,6 @@ translation_beam_size="Beam size" translation_max_decoding_length="Max decoding length" translation_no_repeat_ngram_size="No-repeat ngram size" translation_max_input_length="Max input length" +buffered_output_parameters="Buffered output parameters" +buffer_num_lines="Number of lines" +buffer_num_chars_per_line="Characters per line" diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 6d0e062..f8d729a 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -135,10 +135,6 @@ void set_text_callback(struct transcription_filter_data *gf, } } - if (gf->send_timed_metadata) { - send_timed_metadata(gf, result); - } - if (gf->output_file_path != "" && gf->text_source_name.empty()) { // Check if we should save the sentence if 
(gf->save_only_while_recording && !obs_frontend_recording_active()) { diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index dc56b1a..b52fa09 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -74,7 +74,6 @@ struct transcription_filter_data { std::string source_lang; std::string target_lang; std::string translation_output; - bool buffered_output = false; bool enable_token_ts_dtw = false; std::string suppress_sentences; bool fix_utf8 = true; @@ -104,7 +103,10 @@ struct transcription_filter_data { std::string translation_model_index; std::string translation_model_path_external; + bool buffered_output = false; TokenBufferThread captions_monitor; + int buffered_output_num_lines = 2; + int buffered_output_num_chars = 30; // ctor transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 1f8025f..ed5caa9 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -231,6 +231,8 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); gf->last_sub_render_time = 0; bool new_buffered_output = obs_data_get_bool(s, "buffered_output"); + int new_buffer_num_lines = obs_data_get_int(s, "buffer_num_lines"); + int new_buffer_num_chars_per_line = obs_data_get_int(s, "buffer_num_chars_per_line"); if (new_buffered_output) { obs_log(LOG_INFO, "buffered_output enable"); @@ -245,7 +247,16 @@ void transcription_filter_update(void *data, obs_data_t *s) gf); } }, - 2, 30, std::chrono::seconds(10)); + new_buffer_num_lines, new_buffer_num_chars_per_line, + std::chrono::seconds(10)); + } else { + if (new_buffer_num_lines != gf->buffered_output_num_lines || + new_buffer_num_chars_per_line != gf->buffered_output_num_chars) { + obs_log(LOG_INFO, "buffered_output parameters changed, updating"); + gf->captions_monitor.setNumSentences(new_buffer_num_lines); + gf->captions_monitor.setNumPerSentence( + new_buffer_num_chars_per_line); + } } } else { obs_log(LOG_INFO, "buffered_output disable"); @@ -502,6 +513,9 @@ void transcription_filter_defaults(obs_data_t *s) obs_log(LOG_INFO, "filter defaults"); obs_data_set_default_bool(s, "buffered_output", false); + obs_data_set_default_int(s, "buffer_num_lines", 2); + obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); + obs_data_set_default_bool(s, "vad_enabled", true); obs_data_set_default_int(s, "log_level", LOG_DEBUG); obs_data_set_default_bool(s, "log_words", false); @@ -788,6 +802,29 @@ obs_properties_t *transcription_filter_properties(void *data) obs_property_t *buffered_output_prop = obs_properties_add_bool(ppts, "buffered_output", MT_("buffered_output")); + // add buffered output options group + obs_properties_t *buffered_output_group = obs_properties_create(); + obs_properties_add_group(ppts, "buffered_output_group", MT_("buffered_output_parameters"), + OBS_GROUP_NORMAL, buffered_output_group); + // add buffer lines parameter + obs_properties_add_int_slider(buffered_output_group, "buffer_num_lines", + MT_("buffer_num_lines"), 1, 5, 1); + // add buffer number of characters per line parameter + obs_properties_add_int_slider(buffered_output_group, "buffer_num_chars_per_line", + MT_("buffer_num_chars_per_line"), 1, 100, 1); + + // on enable/disable buffered output, show/hide the group + obs_property_set_modified_callback(buffered_output_prop, [](obs_properties_t *props, + obs_property_t *property, + 
obs_data_t *settings) { + UNUSED_PARAMETER(property); + // If buffered output is enabled, show the buffered output group + const bool show_hide = obs_data_get_bool(settings, "buffered_output"); + obs_property_set_visible(obs_properties_get(props, "buffered_output_group"), + show_hide); + return true; + }); + obs_properties_add_bool(ppts, "log_words", MT_("log_words")); obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream")); obs_properties_add_bool(ppts, "send_timed_metadata", MT_("send_timed_metadata")); diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index 02bddb3..8f95a7d 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -87,16 +87,11 @@ void TokenBufferThread::addSentence(const std::string &sentence) std::lock_guard lock(inputQueueMutex); - // if the inputqueue and sentence don't have a space in them, add a space - if (!inputQueue.empty() && !sentence.empty() && inputQueue.back() != SPACE && - characters.front() != SPACE) { - inputQueue.push_back(SPACE); - } - // add the reconstructed sentence to the wordQueue for (const auto &character : characters) { inputQueue.push_back(character); } + inputQueue.push_back(SPACE); } void TokenBufferThread::clear() @@ -180,6 +175,11 @@ void TokenBufferThread::monitor() // iterate through the presentation queue tokens and build a caption size_t tokensInSentence = 0; for (const auto &token : presentationQueue) { + // skip spaces in the beginning of a sentence (tokensInSentence == 0) + if (token == SPACE && tokensInSentence == 0) { + continue; + } + caption += token; tokensInSentence++; if (tokensInSentence == this->numPerSentence) { @@ -207,8 +207,9 @@ void TokenBufferThread::monitor() } } - // sleep for 100 ms - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + // check the presentation queue size (pqs), if it's big - sleep less + std::this_thread::sleep_for( + std::chrono::milliseconds(presentationQueue.size() > 15 ? 
66 : 100)); } obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 8d73285..a318a0e 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -41,6 +41,9 @@ class TokenBufferThread { bool isEnabled() const { return !stop; } + void setNumSentences(size_t numSentences_) { numSentences = numSentences_; } + void setNumPerSentence(size_t numPerSentence_) { numPerSentence = numPerSentence_; } + private: void monitor(); void log_token_vector(const std::vector &tokens); From 10d0d0a49527ea3dc6bcb1c8258eb6dcc68c246f Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 3 Jun 2024 22:11:16 -0400 Subject: [PATCH 03/12] refactor: Remove unused parameter in set_source_signals function --- src/transcription-filter.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index ed5caa9..5d47df4 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -73,6 +73,7 @@ void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_sour signal_handler_connect( sh, "media_play", [](void *data_, calldata_t *cd) { + UNUSED_PARAMETER(cd); transcription_filter_data *gf_ = static_cast(data_); obs_log(gf_->log_level, "media_play"); @@ -82,6 +83,7 @@ void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_sour signal_handler_connect( sh, "media_started", [](void *data_, calldata_t *cd) { + UNUSED_PARAMETER(cd); transcription_filter_data *gf_ = static_cast(data_); obs_log(gf_->log_level, "media_started"); @@ -92,6 +94,7 @@ void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_sour signal_handler_connect( sh, "media_pause", [](void *data_, calldata_t *cd) { + UNUSED_PARAMETER(cd); transcription_filter_data *gf_ = static_cast(data_); obs_log(gf_->log_level, "media_pause"); @@ -101,6 +104,7 @@ void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_sour signal_handler_connect( sh, "media_restart", [](void *data_, calldata_t *cd) { + UNUSED_PARAMETER(cd); transcription_filter_data *gf_ = static_cast(data_); obs_log(gf_->log_level, "media_restart"); @@ -111,6 +115,7 @@ void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_sour signal_handler_connect( sh, "media_stopped", [](void *data_, calldata_t *cd) { + UNUSED_PARAMETER(cd); transcription_filter_data *gf_ = static_cast(data_); obs_log(gf_->log_level, "media_stopped"); @@ -231,8 +236,8 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); gf->last_sub_render_time = 0; bool new_buffered_output = obs_data_get_bool(s, "buffered_output"); - int new_buffer_num_lines = obs_data_get_int(s, "buffer_num_lines"); - int new_buffer_num_chars_per_line = obs_data_get_int(s, "buffer_num_chars_per_line"); + int new_buffer_num_lines = (int)obs_data_get_int(s, "buffer_num_lines"); + int new_buffer_num_chars_per_line = (int)obs_data_get_int(s, "buffer_num_chars_per_line"); if (new_buffered_output) { obs_log(LOG_INFO, "buffered_output enable"); @@ -256,6 +261,8 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->captions_monitor.setNumSentences(new_buffer_num_lines); gf->captions_monitor.setNumPerSentence( new_buffer_num_chars_per_line); + gf->buffered_output_num_lines = new_buffer_num_lines; + gf->buffered_output_num_chars = 
new_buffer_num_chars_per_line; } } } else { From 255415c6a4f916a9eeb42ccd3bc64d7c9f5b164f Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 3 Jun 2024 22:21:38 -0400 Subject: [PATCH 04/12] refactor: Fix character splitting bug in TokenBufferThread --- src/whisper-utils/token-buffer-thread.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index 8f95a7d..38f77db 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -80,7 +80,7 @@ void TokenBufferThread::addSentence(const std::string &sentence) #else // split to characters std::vector characters; - for (const auto &c : sentence_ws) { + for (const auto &c : sentence) { characters.push_back(std::string(1, c)); } #endif From a0753a4cfd3bfa15b4a442bcde5f0bdcc3208bff Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 3 Jun 2024 23:22:47 -0400 Subject: [PATCH 05/12] refactor: Update buffer size and overlap size in whisper-processing.cpp --- src/whisper-utils/token-buffer-thread.cpp | 4 +-- src/whisper-utils/whisper-processing.cpp | 42 +++++++++++------------ 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index 38f77db..d78dfa1 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -207,9 +207,9 @@ void TokenBufferThread::monitor() } } - // check the presentation queue size (pqs), if it's big - sleep less + // check the input queue size (iqs), if it's big - sleep less std::this_thread::sleep_for( - std::chrono::milliseconds(presentationQueue.size() > 15 ? 66 : 100)); + std::chrono::milliseconds(inputQueue.size() > 15 ? 
66 : 100)); } obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 9d01762..a0cf141 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -161,15 +161,18 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter bool should_free_buffer = false; float *pcm32f_data = (float *)pcm32f_data_; size_t pcm32f_size = pcm32f_num_samples; + const uint64_t original_duration_ms = + (uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE); if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) { obs_log(gf->log_level, "Speech segment is less than 1 second, padding with zeros to 1 second"); const size_t new_size = (size_t)(1.01f * (float)(WHISPER_SAMPLE_RATE)); - // create a new buffer and copy the data to it + // create a new buffer and copy the data to it in the middle pcm32f_data = (float *)bzalloc(new_size * sizeof(float)); memset(pcm32f_data, 0, new_size * sizeof(float)); - memcpy(pcm32f_data, pcm32f_data_, pcm32f_num_samples * sizeof(float)); + memcpy(pcm32f_data + (new_size - pcm32f_num_samples) / 2, pcm32f_data_, + pcm32f_num_samples * sizeof(float)); pcm32f_size = new_size; should_free_buffer = true; } @@ -231,23 +234,27 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter if (token.id >= 50256) { keep = false; } - // if (j == n_tokens - 2 && token.p < 0.5) { - // keep = false; - // } - // if (j == n_tokens - 3 && token.p < 0.4) { - // keep = false; - // } // if the second to last token is .id == 13 ('.'), don't keep it if (j == n_tokens - 2 && token.id == 13) { keep = false; } // token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json - // if (token.id > 50566 && token.id <= 51865) { - // obs_log(gf->log_level, - // "Large time token found (%d), this shouldn't happen", - // token.id); - // return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}}; - // } + if (token.id > 50365 && token.id <= 51865) { + const float time = ((float)token.id - 50365.0f) * 0.02; + const float duration_s = (float)duration_ms / 1000.0f; + const float ratio = std::max(time, duration_s) / + std::min(time, duration_s); + obs_log(gf->log_level, + "Time token found %d -> %.3f. Duration: %.3f. 
Ratio: %.3f.",
+					token.id, time, duration_s, ratio);
+				if (ratio > 3.0f) {
+					// ratio is too high, skip this detection
+					obs_log(gf->log_level,
+						"Time token ratio too high, skipping");
+					return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
+				}
+				keep = false;
+			}
 
 			if (keep) {
 				sentence_p += token.p;
@@ -276,13 +283,6 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 		return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
 	}
 
-	// Check regex for "MBC .*" to detect false prediction
-	std::regex mbc_regex("MBC.*");
-	if (std::regex_match(text, mbc_regex)) {
-		obs_log(gf->log_level, "False prediction detected: %s", text.c_str());
-		return {DETECTION_RESULT_SILENCE, "", t0, t1, {}};
-	}
-
 	return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens};
 	}
 }

From c69f6eb0ef2094c7b16dd02746b3862aca421812 Mon Sep 17 00:00:00 2001
From: Roy Shilkrot 
Date: Mon, 3 Jun 2024 23:25:54 -0400
Subject: [PATCH 06/12] refactor: Remove unused variable in
 run_whisper_inference

---
 src/whisper-utils/whisper-processing.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index a0cf141..11237cd 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -161,8 +161,6 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 	bool should_free_buffer = false;
 	float *pcm32f_data = (float *)pcm32f_data_;
 	size_t pcm32f_size = pcm32f_num_samples;
-	const uint64_t original_duration_ms =
-		(uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE);
 
 	if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) {
 		obs_log(gf->log_level,

From 260448ed95a72003ca88e20d180bf9bf0b3b92ea Mon Sep 17 00:00:00 2001
From: Roy Shilkrot 
Date: Tue, 4 Jun 2024 09:08:12 -0400
Subject: [PATCH 07/12] refactor: Fix floating point precision issue in
 whisper-processing.cpp

---
 src/whisper-utils/whisper-processing.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index 11237cd..3db6f30 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -238,7 +238,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 			}
 			// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
 			if (token.id > 50365 && token.id <= 51865) {
-				const float time = ((float)token.id - 50365.0f) * 0.02;
+				const float time = ((float)token.id - 50365.0f) * 0.02f;
 				const float duration_s = (float)duration_ms / 1000.0f;
 				const float ratio = std::max(time, duration_s) /
 						    std::min(time, duration_s);

From dab16d2a64fd8f8785834f33d40b12498f44145f Mon Sep 17 00:00:00 2001
From: Roy Shilkrot 
Date: Tue, 4 Jun 2024 11:08:02 -0400
Subject: [PATCH 08/12] refactor: Improve remove_leading_trailing_nonalpha
 function in transcription-utils.cpp

---
 src/transcription-utils.cpp | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp
index 415b47b..321d2fb 100644
--- a/src/transcription-utils.cpp
+++ b/src/transcription-utils.cpp
@@ -88,6 +88,27 @@ std::string fix_utf8(const std::string &str)
 */
 std::string remove_leading_trailing_nonalpha(const std::string &str)
 {
+	if (str.size() == 0) {
+		return str;
+	}
+	if (str.size() == 1) {
+		if (std::isalpha(str[0])) {
+			return str;
+		} else {
+			return "";
+		}
+	}
+	if (str.size() == 2) {
+		if (std::isalpha(str[0]) && 
std::isalpha(str[1])) { + return str; + } else if (std::isalpha(str[0])) { + return std::string(1, str[0]); + } else if (std::isalpha(str[1])) { + return std::string(1, str[1]); + } else { + return ""; + } + } std::string str_copy = str; // remove trailing spaces, newlines, tabs or punctuation auto last_non_space = From 5cf661db24c68932707fb8b11d5ce153dc32370d Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 5 Jun 2024 10:52:21 -0400 Subject: [PATCH 09/12] refactor: Update VAD threshold in transcription filter --- data/locale/en-US.ini | 1 + src/transcription-filter.cpp | 12 +++- src/whisper-utils/silero-vad-onnx.h | 1 + src/whisper-utils/token-buffer-thread.cpp | 69 +++++++++++++++-------- src/whisper-utils/token-buffer-thread.h | 6 +- src/whisper-utils/whisper-processing.cpp | 13 +---- 6 files changed, 63 insertions(+), 39 deletions(-) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index 5507446..03efa24 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -1,6 +1,7 @@ LocalVocalPlugin="LocalVocal Plugin" transcription_filterAudioFilter="LocalVocal Transcription" vad_enabled="VAD Enabled" +vad_threshold="VAD Threshold" log_level="Internal Log Level" log_words="Log Output to Console" caption_to_stream="Stream Captions" diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 5d47df4..b32cfb0 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -388,6 +388,11 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature"); gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts"); gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty"); + + if (gf->vad_enabled && gf->vad) { + const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold"); + gf->vad->set_threshold(vad_threshold); + } } void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) @@ -524,6 +529,7 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); obs_data_set_default_bool(s, "vad_enabled", true); + obs_data_set_default_double(s, "vad_threshold", 0.5); obs_data_set_default_int(s, "log_level", LOG_DEBUG); obs_data_set_default_bool(s, "log_words", false); obs_data_set_default_bool(s, "caption_to_stream", false); @@ -799,7 +805,7 @@ obs_properties_t *transcription_filter_properties(void *data) {"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec", "overlap_size_msec", "step_by_step_processing", "min_sub_duration", "process_while_muted", "buffered_output", "vad_enabled", "log_level", - "suppress_sentences", "sentence_psum_accept_thresh", "send_timed_metadata"}) { + "suppress_sentences", "sentence_psum_accept_thresh"}) { obs_property_set_visible(obs_properties_get(props, prop_name.c_str()), show_hide); } @@ -834,7 +840,6 @@ obs_properties_t *transcription_filter_properties(void *data) obs_properties_add_bool(ppts, "log_words", MT_("log_words")); obs_properties_add_bool(ppts, "caption_to_stream", MT_("caption_to_stream")); - obs_properties_add_bool(ppts, "send_timed_metadata", MT_("send_timed_metadata")); obs_properties_add_int_slider(ppts, "min_sub_duration", MT_("min_sub_duration"), 1000, 5000, 50); @@ -844,6 +849,9 @@ obs_properties_t *transcription_filter_properties(void *data) obs_properties_add_bool(ppts, "process_while_muted", MT_("process_while_muted")); obs_properties_add_bool(ppts, "vad_enabled", 
MT_("vad_enabled")); + // add vad threshold slider + obs_properties_add_float_slider(ppts, "vad_threshold", MT_("vad_threshold"), 0.0, 1.0, + 0.05); obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"), OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); diff --git a/src/whisper-utils/silero-vad-onnx.h b/src/whisper-utils/silero-vad-onnx.h index b817383..accb5af 100644 --- a/src/whisper-utils/silero-vad-onnx.h +++ b/src/whisper-utils/silero-vad-onnx.h @@ -53,6 +53,7 @@ class VadIterator { void collect_chunks(const std::vector &input_wav, std::vector &output_wav); const std::vector get_speech_timestamps() const; void drop_chunks(const std::vector &input_wav, std::vector &output_wav); + void set_threshold(float threshold) { this->threshold = threshold; } private: // model config diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index d78dfa1..3338350 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -16,16 +16,18 @@ #define NEWLINE "\n" #endif +TokenBufferThread::TokenBufferThread() noexcept + : gf(nullptr), + numSentences(1), + numPerSentence(1), + maxTime(0), + stop(true) +{ +} + TokenBufferThread::~TokenBufferThread() { - { - std::lock_guard lock(inputQueueMutex); - stop = true; - } - condVar.notify_all(); - if (workerThread.joinable()) { - workerThread.join(); - } + stopThread(); } void TokenBufferThread::initialize(struct transcription_filter_data *gf_, @@ -41,13 +43,18 @@ void TokenBufferThread::initialize(struct transcription_filter_data *gf_, this->segmentation = segmentation_; this->maxTime = maxTime_; this->stop = false; + this->presentationQueueMutex = std::make_unique(); + this->inputQueueMutex = std::make_unique(); this->workerThread = std::thread(&TokenBufferThread::monitor, this); } void TokenBufferThread::stopThread() { - std::lock_guard lock(inputQueueMutex); - stop = true; + { + std::lock_guard lock(*inputQueueMutex); + std::lock_guard lockPresentation(*presentationQueueMutex); + stop = true; + } condVar.notify_all(); if (workerThread.joinable()) { workerThread.join(); @@ -85,7 +92,7 @@ void TokenBufferThread::addSentence(const std::string &sentence) } #endif - std::lock_guard lock(inputQueueMutex); + std::lock_guard lock(*inputQueueMutex); // add the reconstructed sentence to the wordQueue for (const auto &character : characters) { @@ -97,11 +104,11 @@ void TokenBufferThread::addSentence(const std::string &sentence) void TokenBufferThread::clear() { { - std::lock_guard lock(inputQueueMutex); + std::lock_guard lock(*inputQueueMutex); inputQueue.clear(); } { - std::lock_guard lock(presentationQueueMutex); + std::lock_guard lock(*presentationQueueMutex); presentationQueue.clear(); } this->callback(""); @@ -114,8 +121,14 @@ void TokenBufferThread::monitor() this->callback(""); while (!this->stop) { + std::string caption_out; + + if (presentationQueueMutex == nullptr) { + break; + } + { - std::unique_lock lockPresentation(this->presentationQueueMutex); + std::lock_guard lockPresentation(*presentationQueueMutex); // condition presentation queue if (presentationQueue.size() == this->numSentences * this->numPerSentence) { // pop a whole sentence from the presentation queue front @@ -125,7 +138,11 @@ void TokenBufferThread::monitor() } { - std::unique_lock lock(this->inputQueueMutex); + if (inputQueueMutex == nullptr) { + break; + } + + std::lock_guard lock(*inputQueueMutex); if (!inputQueue.empty()) { // if there are token on the input queue @@ -194,22 
+211,30 @@ void TokenBufferThread::monitor() int count = WideCharToMultiByte(CP_UTF8, 0, caption.c_str(), (int)caption.length(), NULL, 0, NULL, NULL); - std::string caption_out(count, 0); + caption_out = std::string(count, 0); WideCharToMultiByte(CP_UTF8, 0, caption.c_str(), (int)caption.length(), &caption_out[0], count, NULL, NULL); #else - std::string caption_out(caption.begin(), caption.end()); + caption_out = std::string(caption.begin(), caption.end()); #endif - - // emit the caption - this->callback(caption_out); } } + if (caption_out.empty()) { + // if no caption was built, sleep for a while + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + // emit the caption + this->callback(caption_out); + // check the input queue size (iqs), if it's big - sleep less - std::this_thread::sleep_for( - std::chrono::milliseconds(inputQueue.size() > 15 ? 66 : 100)); + std::this_thread::sleep_for(std::chrono::milliseconds(inputQueue.size() > 30 ? 33 + : inputQueue.size() > 15 + ? 66 + : 100)); } obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index a318a0e..83afa45 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -27,7 +27,7 @@ enum TokenBufferSegmentation { SEGMENTATION_WORD = 0, SEGMENTATION_TOKEN, SEGMEN class TokenBufferThread { public: // default constructor - TokenBufferThread() = default; + TokenBufferThread() noexcept; ~TokenBufferThread(); void initialize(struct transcription_filter_data *gf, @@ -51,8 +51,8 @@ class TokenBufferThread { std::deque inputQueue; std::deque presentationQueue; std::thread workerThread; - std::mutex inputQueueMutex; - std::mutex presentationQueueMutex; + std::unique_ptr inputQueueMutex; + std::unique_ptr presentationQueueMutex; std::condition_variable condVar; std::function callback; std::chrono::seconds maxTime; diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 3db6f30..9bd0837 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -316,20 +316,9 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v uint32_t num_frames_from_infos = 0; uint64_t start_timestamp = 0; uint64_t end_timestamp = 0; - size_t overlap_size = 0; //gf->sample_rate / 10; + size_t overlap_size = 0; for (size_t c = 0; c < gf->channels; c++) { - // if (!current_vad_on && gf->last_num_frames > overlap_size) { - // if (c == 0) { - // // print only once - // obs_log(gf->log_level, "VAD overlap: %lu frames", overlap_size); - // } - // // move 100ms from the end of copy_buffers to the beginning - // memmove(gf->copy_buffers[c], gf->copy_buffers[c] + gf->last_num_frames - overlap_size, - // overlap_size * sizeof(float)); - // } else { - // overlap_size = 0; - // } // zero the rest of copy_buffers memset(gf->copy_buffers[c] + overlap_size, 0, (gf->frames - overlap_size) * sizeof(float)); From 191bbe2baabb79c561e14042adea74cfd842ca7c Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 5 Jun 2024 11:03:58 -0400 Subject: [PATCH 10/12] refactor: Update VAD threshold parameter name in silero-vad-onnx.h --- src/whisper-utils/silero-vad-onnx.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/whisper-utils/silero-vad-onnx.h b/src/whisper-utils/silero-vad-onnx.h index accb5af..3d4de6c 100644 --- a/src/whisper-utils/silero-vad-onnx.h +++ b/src/whisper-utils/silero-vad-onnx.h @@ 
-53,7 +53,7 @@ class VadIterator {
 	void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
 	const std::vector<timestamp_t> get_speech_timestamps() const;
 	void drop_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
-	void set_threshold(float threshold) { this->threshold = threshold; }
+	void set_threshold(float threshold_) { this->threshold = threshold_; }
 
 private:
 	// model config

From e2e0c90fd252df1199e94d117184559982093e25 Mon Sep 17 00:00:00 2001
From: Roy Shilkrot 
Date: Wed, 5 Jun 2024 17:44:11 -0400
Subject: [PATCH 11/12] refactor: Extract media signal callbacks and clean up
 filter teardown

---
 src/transcription-filter-callbacks.cpp    |  63 +++++++++++++
 src/transcription-filter-callbacks.h      |   6 ++
 src/transcription-filter.c                |   1 +
 src/transcription-filter.cpp              | 104 ++++++----------------
 src/transcription-filter.h                |   1 +
 src/whisper-utils/token-buffer-thread.cpp |  39 ++++----
 src/whisper-utils/token-buffer-thread.h   |   8 +-
 7 files changed, 123 insertions(+), 99 deletions(-)

diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp
index f8d729a..34b5d97 100644
--- a/src/transcription-filter-callbacks.cpp
+++ b/src/transcription-filter-callbacks.cpp
@@ -222,3 +222,66 @@ void recording_state_callback(enum obs_frontend_event event, void *data)
 		}
 	}
 }
+
+void reset_caption_state(transcription_filter_data *gf_)
+{
+	if (gf_->captions_monitor.isEnabled()) {
+		gf_->captions_monitor.clear();
+	}
+	send_caption_to_source(gf_->text_source_name, "", gf_);
+	// flush the buffer
+	{
+		std::lock_guard<std::mutex> lock(gf_->whisper_buf_mutex);
+		for (size_t c = 0; c < gf_->channels; c++) {
+			if (gf_->input_buffers[c].data != nullptr) {
+				circlebuf_free(&gf_->input_buffers[c]);
+			}
+		}
+		if (gf_->info_buffer.data != nullptr) {
+			circlebuf_free(&gf_->info_buffer);
+		}
+		if (gf_->whisper_buffer.data != nullptr) {
+			circlebuf_free(&gf_->whisper_buffer);
+		}
+	}
+}
+
+void media_play_callback(void *data_, calldata_t *cd)
+{
+	UNUSED_PARAMETER(cd);
+	transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_);
+	obs_log(gf_->log_level, "media_play");
+	gf_->active = true;
+}
+
+void media_started_callback(void *data_, calldata_t *cd)
+{
+	UNUSED_PARAMETER(cd);
+	transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_);
+	obs_log(gf_->log_level, "media_started");
+	gf_->active = true;
+	reset_caption_state(gf_);
+}
+void media_pause_callback(void *data_, calldata_t *cd)
+{
+	UNUSED_PARAMETER(cd);
+	transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_);
+	obs_log(gf_->log_level, "media_pause");
+	gf_->active = false;
+}
+void media_restart_callback(void *data_, calldata_t *cd)
+{
+	UNUSED_PARAMETER(cd);
+	transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_);
+	obs_log(gf_->log_level, "media_restart");
+	gf_->active = true;
+	reset_caption_state(gf_);
+}
+void media_stopped_callback(void *data_, calldata_t *cd)
+{
+	UNUSED_PARAMETER(cd);
+	transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_);
+	obs_log(gf_->log_level, "media_stopped");
+	gf_->active = false;
+	reset_caption_state(gf_);
+}
diff --git a/src/transcription-filter-callbacks.h b/src/transcription-filter-callbacks.h
index 481af9f..a49f099 100644
--- a/src/transcription-filter-callbacks.h
+++ b/src/transcription-filter-callbacks.h
@@ -17,4 +17,10 @@ void set_text_callback(struct transcription_filter_data *gf,
 
 void recording_state_callback(enum obs_frontend_event event, void *data);
 
+void media_play_callback(void *data_, calldata_t *cd);
+void media_started_callback(void *data_, calldata_t *cd);
+void media_pause_callback(void *data_, 
calldata_t *cd); +void media_restart_callback(void *data_, calldata_t *cd); +void media_stopped_callback(void *data_, calldata_t *cd); + #endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */ diff --git a/src/transcription-filter.c b/src/transcription-filter.c index 505a106..6162fab 100644 --- a/src/transcription-filter.c +++ b/src/transcription-filter.c @@ -13,4 +13,5 @@ struct obs_source_info transcription_filter_info = { .activate = transcription_filter_activate, .deactivate = transcription_filter_deactivate, .filter_audio = transcription_filter_filter_audio, + .filter_remove = transcription_filter_remove, }; diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index b32cfb0..2fc25d3 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -44,86 +44,26 @@ bool add_sources_to_list(void *list_property, obs_source_t *source) return true; } -void reset_caption_state(transcription_filter_data *gf_) +void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) { - if (gf_->captions_monitor.isEnabled()) { - gf_->captions_monitor.clear(); - } - send_caption_to_source(gf_->text_source_name, "", gf_); - // flush the buffer - { - std::lock_guard lock(gf_->whisper_buf_mutex); - for (size_t c = 0; c < gf_->channels; c++) { - if (gf_->input_buffers[c].data != nullptr) { - circlebuf_free(&gf_->input_buffers[c]); - } - } - if (gf_->info_buffer.data != nullptr) { - circlebuf_free(&gf_->info_buffer); - } - if (gf_->whisper_buffer.data != nullptr) { - circlebuf_free(&gf_->whisper_buffer); - } - } + signal_handler_t *sh = obs_source_get_signal_handler(parent_source); + signal_handler_connect(sh, "media_play", media_play_callback, gf); + signal_handler_connect(sh, "media_started", media_started_callback, gf); + signal_handler_connect(sh, "media_pause", media_pause_callback, gf); + signal_handler_connect(sh, "media_restart", media_restart_callback, gf); + signal_handler_connect(sh, "media_stopped", media_stopped_callback, gf); + gf->source_signals_set = true; } -void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) +void disconnect_source_signals(transcription_filter_data *gf, obs_source_t *parent_source) { signal_handler_t *sh = obs_source_get_signal_handler(parent_source); - signal_handler_connect( - sh, "media_play", - [](void *data_, calldata_t *cd) { - UNUSED_PARAMETER(cd); - transcription_filter_data *gf_ = - static_cast(data_); - obs_log(gf_->log_level, "media_play"); - gf_->active = true; - }, - gf); - signal_handler_connect( - sh, "media_started", - [](void *data_, calldata_t *cd) { - UNUSED_PARAMETER(cd); - transcription_filter_data *gf_ = - static_cast(data_); - obs_log(gf_->log_level, "media_started"); - gf_->active = true; - reset_caption_state(gf_); - }, - gf); - signal_handler_connect( - sh, "media_pause", - [](void *data_, calldata_t *cd) { - UNUSED_PARAMETER(cd); - transcription_filter_data *gf_ = - static_cast(data_); - obs_log(gf_->log_level, "media_pause"); - gf_->active = false; - }, - gf); - signal_handler_connect( - sh, "media_restart", - [](void *data_, calldata_t *cd) { - UNUSED_PARAMETER(cd); - transcription_filter_data *gf_ = - static_cast(data_); - obs_log(gf_->log_level, "media_restart"); - gf_->active = true; - reset_caption_state(gf_); - }, - gf); - signal_handler_connect( - sh, "media_stopped", - [](void *data_, calldata_t *cd) { - UNUSED_PARAMETER(cd); - transcription_filter_data *gf_ = - static_cast(data_); - obs_log(gf_->log_level, "media_stopped"); - gf_->active = false; - 
-			reset_caption_state(gf_);
-		},
-		gf);
-	gf->source_signals_set = true;
+	signal_handler_disconnect(sh, "media_play", media_play_callback, gf);
+	signal_handler_disconnect(sh, "media_started", media_started_callback, gf);
+	signal_handler_disconnect(sh, "media_pause", media_pause_callback, gf);
+	signal_handler_disconnect(sh, "media_restart", media_restart_callback, gf);
+	signal_handler_disconnect(sh, "media_stopped", media_stopped_callback, gf);
+	gf->source_signals_set = false;
 }
 
 struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_audio_data *audio)
@@ -190,6 +130,16 @@ const char *transcription_filter_name(void *unused)
 	return MT_("transcription_filterAudioFilter");
 }
 
+void transcription_filter_remove(void *data, obs_source_t *source)
+{
+	struct transcription_filter_data *gf =
+		static_cast<struct transcription_filter_data *>(data);
+
+	obs_log(gf->log_level, "filter remove");
+
+	disconnect_source_signals(gf, source);
+}
+
 void transcription_filter_destroy(void *data)
 {
 	struct transcription_filter_data *gf =
@@ -212,6 +162,10 @@ void transcription_filter_destroy(void *data)
 	}
 	circlebuf_free(&gf->info_buffer);
 
+	if (gf->captions_monitor.isEnabled()) {
+		gf->captions_monitor.stopThread();
+	}
+
 	bfree(gf);
 }
diff --git a/src/transcription-filter.h b/src/transcription-filter.h
index 6d07127..922d44e 100644
--- a/src/transcription-filter.h
+++ b/src/transcription-filter.h
@@ -15,6 +15,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
 void transcription_filter_deactivate(void *data);
 void transcription_filter_defaults(obs_data_t *s);
 obs_properties_t *transcription_filter_properties(void *data);
+void transcription_filter_remove(void *data, obs_source_t *source);
 
 const char *const PLUGIN_INFO_TEMPLATE =
 	"LocalVocal (%1) by "
diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp
index 3338350..a8a25f4 100644
--- a/src/whisper-utils/token-buffer-thread.cpp
+++ b/src/whisper-utils/token-buffer-thread.cpp
@@ -21,7 +21,9 @@ TokenBufferThread::TokenBufferThread() noexcept
 	  numSentences(1),
 	  numPerSentence(1),
 	  maxTime(0),
-	  stop(true)
+	  stop(true),
+	  presentationQueueMutex(),
+	  inputQueueMutex()
 {
 }
 
@@ -43,19 +45,16 @@ void TokenBufferThread::initialize(struct transcription_filter_data *gf_,
 	this->segmentation = segmentation_;
 	this->maxTime = maxTime_;
 	this->stop = false;
-	this->presentationQueueMutex = std::make_unique<std::mutex>();
-	this->inputQueueMutex = std::make_unique<std::mutex>();
 	this->workerThread = std::thread(&TokenBufferThread::monitor, this);
 }
 
 void TokenBufferThread::stopThread()
 {
 	{
-		std::lock_guard<std::mutex> lock(*inputQueueMutex);
-		std::lock_guard<std::mutex> lockPresentation(*presentationQueueMutex);
+		std::lock_guard<std::mutex> lock(presentationQueueMutex);
 		stop = true;
 	}
-	condVar.notify_all();
+	cv.notify_all();
 	if (workerThread.joinable()) {
 		workerThread.join();
 	}
@@ -92,7 +91,7 @@ void TokenBufferThread::addSentence(const std::string &sentence)
 	}
 #endif
 
-	std::lock_guard<std::mutex> lock(*inputQueueMutex);
+	std::lock_guard<std::mutex> lock(inputQueueMutex);
 
 	// add the reconstructed sentence to the wordQueue
 	for (const auto &character : characters) {
@@ -104,11 +103,11 @@ void TokenBufferThread::clear()
 {
 	{
-		std::lock_guard<std::mutex> lock(*inputQueueMutex);
+		std::lock_guard<std::mutex> lock(inputQueueMutex);
 		inputQueue.clear();
 	}
 	{
-		std::lock_guard<std::mutex> lock(*presentationQueueMutex);
+		std::lock_guard<std::mutex> lock(presentationQueueMutex);
 		presentationQueue.clear();
 	}
 	this->callback("");
@@ -120,15 +119,15 @@ void TokenBufferThread::monitor()
this->callback(""); - while (!this->stop) { + while (true) { std::string caption_out; - if (presentationQueueMutex == nullptr) { - break; - } - { - std::lock_guard lockPresentation(*presentationQueueMutex); + std::lock_guard lock(presentationQueueMutex); + if (stop) { + break; + } + // condition presentation queue if (presentationQueue.size() == this->numSentences * this->numPerSentence) { // pop a whole sentence from the presentation queue front @@ -138,11 +137,7 @@ void TokenBufferThread::monitor() } { - if (inputQueueMutex == nullptr) { - break; - } - - std::lock_guard lock(*inputQueueMutex); + std::lock_guard lock(inputQueueMutex); if (!inputQueue.empty()) { // if there are token on the input queue @@ -221,6 +216,10 @@ void TokenBufferThread::monitor() } } + if (this->stop) { + break; + } + if (caption_out.empty()) { // if no caption was built, sleep for a while std::this_thread::sleep_for(std::chrono::milliseconds(100)); diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 83afa45..0dbe14e 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -51,12 +51,12 @@ class TokenBufferThread { std::deque inputQueue; std::deque presentationQueue; std::thread workerThread; - std::unique_ptr inputQueueMutex; - std::unique_ptr presentationQueueMutex; - std::condition_variable condVar; + std::mutex inputQueueMutex; + std::mutex presentationQueueMutex; std::function callback; + std::condition_variable cv; std::chrono::seconds maxTime; - bool stop = true; + std::atomic stop; bool newDataAvailable = false; size_t numSentences; size_t numPerSentence; From 04c08a5c691a47762fbfe1d3fe893a749cd0d1dc Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Wed, 5 Jun 2024 17:49:57 -0400 Subject: [PATCH 12/12] refactor: Update lock_guard parameter name in TokenBufferThread --- src/whisper-utils/token-buffer-thread.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index a8a25f4..e88d76c 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -123,7 +123,7 @@ void TokenBufferThread::monitor() std::string caption_out; { - std::lock_guard lock(presentationQueueMutex); + std::lock_guard lockPresentation(presentationQueueMutex); if (stop) { break; }