diff --git a/CMakeLists.txt b/CMakeLists.txt index 7eff873..e64f45c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,11 @@ else() include(cmake/FetchOnnxruntime.cmake) endif() +include(cmake/BuildICU.cmake) +# Add ICU to the target +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU) +target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR}) + target_sources( ${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c @@ -114,9 +119,11 @@ target_sources( src/whisper-utils/whisper-model-utils.cpp src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp + src/whisper-utils/vad-processing.cpp src/translation/language_codes.cpp src/translation/translation.cpp src/translation/translation-utils.cpp + src/translation/translation-language-utils.cpp src/ui/filter-replace-dialog.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) @@ -137,12 +144,14 @@ if(ENABLE_TESTS) src/whisper-utils/whisper-utils.cpp src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp + src/whisper-utils/vad-processing.cpp src/translation/language_codes.cpp - src/translation/translation.cpp) + src/translation/translation.cpp + src/translation/translation-language-utils.cpp) find_libav(${CMAKE_PROJECT_NAME}-tests) - target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs) + target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs ICU) target_include_directories(${CMAKE_PROJECT_NAME}-tests PRIVATE src) # install the tests to the release/test directory diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake new file mode 100644 index 0000000..a3c575d --- /dev/null +++ b/cmake/BuildICU.cmake @@ -0,0 +1,101 @@ +include(FetchContent) +include(ExternalProject) + +set(ICU_VERSION "75.1") +set(ICU_VERSION_UNDERSCORE "75_1") +set(ICU_VERSION_DASH "75-1") +set(ICU_VERSION_NO_MINOR "75") + +if(WIN32) + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-Win64-MSVC2022.zip" + ) + set(ICU_HASH "SHA256=7ac9c0dc6ccc1ec809c7d5689b8d831c5b8f6b11ecf70fdccc55f7ae8731ac8f") + + FetchContent_Declare( + ICU_build + URL ${ICU_URL} + URL_HASH ${ICU_HASH}) + + FetchContent_MakeAvailable(ICU_build) + + # Assuming the ZIP structure, adjust paths as necessary + set(ICU_INCLUDE_DIR "${icu_build_SOURCE_DIR}/include") + set(ICU_LIBRARY_DIR "${icu_build_SOURCE_DIR}/lib64") + set(ICU_BINARY_DIR "${icu_build_SOURCE_DIR}/bin64") + + # Define the library names + set(ICU_LIBRARIES icudt icuuc icuin) + + foreach(lib ${ICU_LIBRARIES}) + # Add ICU library + find_library( + ICU_LIB_${lib} + NAMES ${lib} + PATHS ${ICU_LIBRARY_DIR} + NO_DEFAULT_PATH REQUIRED) + # find the dll + find_file( + ICU_DLL_${lib} + NAMES ${lib}${ICU_VERSION_NO_MINOR}.dll + PATHS ${ICU_BINARY_DIR} + NO_DEFAULT_PATH) + # Copy the DLLs to the output directory + install(FILES ${ICU_DLL_${lib}} DESTINATION "obs-plugins/64bit") + # add the library + add_library(ICU::${lib} SHARED IMPORTED GLOBAL) + set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIB_${lib}}" IMPORTED_IMPLIB + "${ICU_LIB_${lib}}") + endforeach() +else() + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-src.tgz" + ) + set(ICU_HASH "SHA256=cb968df3e4d2e87e8b11c49a5d01c787bd13b9545280fc6642f826527618caef") + if(APPLE) + set(ICU_PLATFORM "MacOSX") + 
set(TARGET_ARCH -arch\ $ENV{MACOS_ARCH}) + set(ICU_BUILD_ENV_VARS CFLAGS=${TARGET_ARCH} CXXFLAGS=${TARGET_ARCH} LDFLAGS=${TARGET_ARCH}) + else() + set(ICU_PLATFORM "Linux") + set(ICU_BUILD_ENV_VARS CFLAGS=-fPIC CXXFLAGS=-fPIC LDFLAGS=-fPIC) + endif() + + ExternalProject_Add( + ICU_build + DOWNLOAD_EXTRACT_TIMESTAMP true + GIT_REPOSITORY "https://github.com/unicode-org/icu.git" + GIT_TAG "release-${ICU_VERSION_DASH}" + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${ICU_BUILD_ENV_VARS} <SOURCE_DIR>/icu4c/source/runConfigureICU + ${ICU_PLATFORM} --prefix=<INSTALL_DIR> --enable-static --disable-shared + BUILD_COMMAND make -j4 + BUILD_BYPRODUCTS + <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icudata${CMAKE_STATIC_LIBRARY_SUFFIX} + <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icuuc${CMAKE_STATIC_LIBRARY_SUFFIX} + <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icui18n${CMAKE_STATIC_LIBRARY_SUFFIX} + INSTALL_COMMAND make install + BUILD_IN_SOURCE 1) + + ExternalProject_Get_Property(ICU_build INSTALL_DIR) + + set(ICU_INCLUDE_DIR "${INSTALL_DIR}/include") + set(ICU_LIBRARY_DIR "${INSTALL_DIR}/lib") + + set(ICU_LIBRARIES icudata icuuc icui18n) + + foreach(lib ${ICU_LIBRARIES}) + add_library(ICU::${lib} STATIC IMPORTED GLOBAL) + add_dependencies(ICU::${lib} ICU_build) + set(ICU_LIBRARY "${ICU_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${lib}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIBRARY}" INTERFACE_INCLUDE_DIRECTORIES + "${ICU_INCLUDE_DIR}") + endforeach(lib ${ICU_LIBRARIES}) +endif() + +# Create an interface target for ICU +add_library(ICU INTERFACE) +add_dependencies(ICU ICU_build) +foreach(lib ${ICU_LIBRARIES}) + target_link_libraries(ICU INTERFACE ICU::${lib}) +endforeach() +target_include_directories(ICU SYSTEM INTERFACE $<BUILD_INTERFACE:${ICU_INCLUDE_DIR}>) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index 0f7b661..9ef4d18 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -1,12 +1,9 @@ LocalVocalPlugin="LocalVocal Plugin" transcription_filterAudioFilter="LocalVocal Transcription" -vad_enabled="VAD Enabled" vad_threshold="VAD Threshold" log_level="Internal Log Level" log_words="Log Output to Console" caption_to_stream="Stream Captions" -step_by_step_processing="Step-by-step processing (⚠️ increased processing)" -step_size_msec="Step size (ms)" subtitle_sources="Output Destination" none_no_output="None / No output" file_output_enable="Save to File" @@ -51,7 +48,6 @@ translate="Translation" translate_add_context="Translate with context" whisper_translate="Translate to English (Whisper)" buffer_size_msec="Buffer size (ms)" -overlap_size_msec="Overlap size (ms)" suppress_sentences="Suppress sentences (each line)" translate_output="Output Destination" dtw_token_timestamps="DTW token timestamps" @@ -85,4 +81,10 @@ buffered_output_parameters="Buffered Output Configuration" file_output_info="Note: Translation output will be saved to a file in the same directory with the target language added to the name, e.g. 'output_es.srt'." partial_transcription="Enable Partial Transcription" partial_transcription_info="Partial transcription will increase processing load on your machine to transcribe content in real-time, which may impact performance."
-partial_latency="Latency (ms)" \ No newline at end of file +partial_latency="Latency (ms)" +vad_mode="VAD Mode" +Active_VAD="Active VAD" +Hybrid_VAD="Hybrid VAD" +translate_only_full_sentences="Translate only full sentences" +duration_filter_threshold="Duration filter" +segment_duration="Segment duration" \ No newline at end of file diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index 8fec08b..ee936af 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -17,6 +17,7 @@ #include "transcription-filter.h" #include "transcription-utils.h" #include "whisper-utils/whisper-utils.h" +#include "whisper-utils/vad-processing.h" #include "audio-file-utils.h" #include "translation/language_codes.h" @@ -148,7 +149,7 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p // }, // 30, std::chrono::seconds(10)); - gf->vad_enabled = true; + gf->vad_mode = VAD_MODE_ACTIVE; gf->log_words = true; gf->caption_to_stream = false; gf->start_timestamp_ms = now_ms(); @@ -157,7 +158,7 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p gf->buffered_output = false; gf->target_lang = ""; - gf->translation_ctx.add_context = true; + gf->translation_ctx.add_context = 1; gf->translation_output = ""; gf->translate = false; gf->sentence_psum_accept_thresh = 0.4; diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index f5c2209..7b8208f 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -53,8 +53,8 @@ std::string send_sentence_to_translation(const std::string &sentence, struct transcription_filter_data *gf, const std::string &source_language) { - const std::string last_text = gf->last_text; - gf->last_text = sentence; + const std::string last_text = gf->last_text_for_translation; + gf->last_text_for_translation = sentence; if (gf->translate && !sentence.empty()) { obs_log(gf->log_level, "Translating text. %s -> %s", source_language.c_str(), gf->target_lang.c_str()); @@ -199,11 +199,6 @@ void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &resultIn) { DetectionResultWithText result = resultIn; - if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || - result.result == DETECTION_RESULT_PARTIAL)) { - gf->last_sub_render_time = now_ms(); - gf->cleared_last_sub = false; - } std::string str_copy = result.text; @@ -233,9 +228,12 @@ void set_text_callback(struct transcription_filter_data *gf, } } + bool should_translate = + gf->translate_only_full_sentences ? result.result == DETECTION_RESULT_SPEECH : true; + // send the sentence to translation (if enabled) std::string translated_sentence = - send_sentence_to_translation(str_copy, gf, result.language); + should_translate ? 
send_sentence_to_translation(str_copy, gf, result.language) : ""; if (gf->translate) { if (gf->translation_output == "none") { @@ -243,10 +241,12 @@ void set_text_callback(struct transcription_filter_data *gf, str_copy = translated_sentence; } else { if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - // buffered output - add the sentence to the monitor - gf->translation_monitor.addSentence(translated_sentence); - } + // buffered output - add the sentence to the monitor + gf->translation_monitor.addSentenceFromStdString( + translated_sentence, + get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->translation_output, translated_sentence, @@ -256,9 +256,10 @@ void set_text_callback(struct transcription_filter_data *gf, } if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - gf->captions_monitor.addSentence(str_copy); - } + gf->captions_monitor.addSentenceFromStdString( + str_copy, get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->text_source_name, str_copy, gf); @@ -273,6 +274,21 @@ void set_text_callback(struct transcription_filter_data *gf, result.result == DETECTION_RESULT_SPEECH) { send_sentence_to_file(gf, result, str_copy, translated_sentence); } + + if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || + result.result == DETECTION_RESULT_PARTIAL)) { + gf->last_sub_render_time = now_ms(); + gf->cleared_last_sub = false; + if (result.result == DETECTION_RESULT_SPEECH) { + // save the last subtitle if it was a full sentence + gf->last_transcription_sentence.push_back(result.text); + // remove the oldest sentence if the buffer is too long + while (gf->last_transcription_sentence.size() > + (size_t)gf->n_context_sentences) { + gf->last_transcription_sentence.pop_front(); + } + } + } }; void recording_state_callback(enum obs_frontend_event event, void *data) @@ -314,6 +330,12 @@ void reset_caption_state(transcription_filter_data *gf_) } send_caption_to_source(gf_->text_source_name, "", gf_); send_caption_to_source(gf_->translation_output, "", gf_); + // reset translation context + gf_->last_text_for_translation = ""; + gf_->last_text_translation = ""; + gf_->translation_ctx.last_input_tokens.clear(); + gf_->translation_ctx.last_translation_tokens.clear(); + gf_->last_transcription_sentence.clear(); // flush the buffer { std::lock_guard lock(gf_->whisper_buf_mutex); diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 4b16d13..e1af694 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -36,6 +36,8 @@ struct transcription_filter_data { size_t sentence_number; // Minimal subtitle duration in ms size_t min_sub_duration; + // Maximal subtitle duration in ms + size_t max_sub_duration; // Last time a subtitle was rendered uint64_t last_sub_render_time; bool cleared_last_sub; @@ -62,7 +64,7 @@ struct transcription_filter_data { float sentence_psum_accept_thresh; bool do_silence; - bool vad_enabled; + int vad_mode; int log_level = LOG_DEBUG; bool log_words; bool caption_to_stream; @@ -84,11 +86,17 @@ struct transcription_filter_data { bool initial_creation = true; bool 
partial_transcription = false; int partial_latency = 1000; + float duration_filter_threshold = 2.25f; + int segment_duration = 7000; // Last transcription result - std::string last_text; + std::string last_text_for_translation; std::string last_text_translation; + // Transcription context sentences + int n_context_sentences; + std::deque<std::string> last_transcription_sentence; + // Text source to output the subtitles std::string text_source_name; // Callback to set the text in the output text source (subtitles) @@ -110,6 +118,7 @@ struct transcription_filter_data { struct translation_context translation_ctx; std::string translation_model_index; std::string translation_model_path_external; + bool translate_only_full_sentences; bool buffered_output = false; TokenBufferThread captions_monitor; diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 523bbf8..4a3693f 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -7,6 +7,7 @@ #include "transcription-filter.h" #include "transcription-filter-utils.h" #include "whisper-utils/whisper-language.h" +#include "whisper-utils/vad-processing.h" #include "model-utils/model-downloader-types.h" #include "translation/language_codes.h" #include "ui/filter-replace-dialog.h" @@ -212,8 +213,12 @@ void add_translation_group_properties(obs_properties_t *ppts) obs_property_t *prop_tgt = obs_properties_add_list( translation_group, "translate_target_language", MT_("target_language"), OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING); - obs_properties_add_bool(translation_group, "translate_add_context", - MT_("translate_add_context")); + + // add slider for number of context lines to add to the translation + obs_properties_add_int_slider(translation_group, "translate_add_context", + MT_("translate_add_context"), 0, 5, 1); + obs_properties_add_bool(translation_group, "translate_only_full_sentences", + MT_("translate_only_full_sentences")); // Populate the dropdown with the language codes for (const auto &language : language_codes) { @@ -290,6 +295,31 @@ void add_buffered_output_group_properties(obs_properties_t *ppts) OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); obs_property_list_add_int(buffer_type_list, "Character", SEGMENTATION_TOKEN); obs_property_list_add_int(buffer_type_list, "Word", SEGMENTATION_WORD); + obs_property_list_add_int(buffer_type_list, "Sentence", SEGMENTATION_SENTENCE); + // add callback to the segmentation selection to set default values + obs_property_set_modified_callback(buffer_type_list, [](obs_properties_t *props, + obs_property_t *property, + obs_data_t *settings) { + UNUSED_PARAMETER(property); + UNUSED_PARAMETER(props); + const int segmentation_type = (int)obs_data_get_int(settings, "buffer_output_type"); + // set default values for the number of lines and characters per line + switch (segmentation_type) { + case SEGMENTATION_TOKEN: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 30); + break; + case SEGMENTATION_WORD: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 10); + break; + case SEGMENTATION_SENTENCE: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 2); + break; + } + return true; + }); // add buffer lines parameter obs_properties_add_int_slider(buffered_output_group, "buffer_num_lines", MT_("buffer_num_lines"), 1, 5, 1); @@ -310,16 +340,29 @@ void 
add_advanced_group_properties(obs_properties_t *ppts, struct transcription_ obs_properties_add_int_slider(advanced_config_group, "min_sub_duration", MT_("min_sub_duration"), 1000, 5000, 50); + obs_properties_add_int_slider(advanced_config_group, "max_sub_duration", + MT_("max_sub_duration"), 1000, 5000, 50); obs_properties_add_float_slider(advanced_config_group, "sentence_psum_accept_thresh", MT_("sentence_psum_accept_thresh"), 0.0, 1.0, 0.05); obs_properties_add_bool(advanced_config_group, "process_while_muted", MT_("process_while_muted")); - obs_properties_add_bool(advanced_config_group, "vad_enabled", MT_("vad_enabled")); + // add selection for Active VAD vs Hybrid VAD + obs_property_t *vad_mode_list = + obs_properties_add_list(advanced_config_group, "vad_mode", MT_("vad_mode"), + OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); + obs_property_list_add_int(vad_mode_list, MT_("Active_VAD"), VAD_MODE_ACTIVE); + obs_property_list_add_int(vad_mode_list, MT_("Hybrid_VAD"), VAD_MODE_HYBRID); // add vad threshold slider obs_properties_add_float_slider(advanced_config_group, "vad_threshold", MT_("vad_threshold"), 0.0, 1.0, 0.05); + // add duration filter threshold slider + obs_properties_add_float_slider(advanced_config_group, "duration_filter_threshold", + MT_("duration_filter_threshold"), 0.1, 3.0, 0.05); + // add segment duration slider + obs_properties_add_int_slider(advanced_config_group, "segment_duration", + MT_("segment_duration"), 3000, 15000, 100); // add button to open filter and replace UI dialog obs_properties_add_button2( @@ -371,6 +414,10 @@ void add_whisper_params_group_properties(obs_properties_t *ppts) WHISPER_SAMPLING_BEAM_SEARCH); obs_property_list_add_int(whisper_sampling_method_list, "Greedy", WHISPER_SAMPLING_GREEDY); + // add int slider for context sentences + obs_properties_add_int_slider(whisper_params_group, "n_context_sentences", + MT_("n_context_sentences"), 0, 5, 1); + // int n_threads; obs_properties_add_int_slider(whisper_params_group, "n_threads", MT_("n_threads"), 1, 8, 1); // int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder @@ -507,3 +554,77 @@ obs_properties_t *transcription_filter_properties(void *data) UNUSED_PARAMETER(data); return ppts; } + +void transcription_filter_defaults(obs_data_t *s) +{ + obs_log(LOG_DEBUG, "filter defaults"); + + obs_data_set_default_bool(s, "buffered_output", false); + obs_data_set_default_int(s, "buffer_num_lines", 2); + obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); + obs_data_set_default_int(s, "buffer_output_type", + (int)TokenBufferSegmentation::SEGMENTATION_TOKEN); + + obs_data_set_default_int(s, "vad_mode", VAD_MODE_ACTIVE); + obs_data_set_default_double(s, "vad_threshold", 0.65); + obs_data_set_default_double(s, "duration_filter_threshold", 2.25); + obs_data_set_default_int(s, "segment_duration", 7000); + obs_data_set_default_int(s, "log_level", LOG_DEBUG); + obs_data_set_default_bool(s, "log_words", false); + obs_data_set_default_bool(s, "caption_to_stream", false); + obs_data_set_default_string(s, "whisper_model_path", "Whisper Tiny English (74Mb)"); + obs_data_set_default_string(s, "whisper_language_select", "en"); + obs_data_set_default_string(s, "subtitle_sources", "none"); + obs_data_set_default_bool(s, "process_while_muted", false); + obs_data_set_default_bool(s, "subtitle_save_srt", false); + obs_data_set_default_bool(s, "truncate_output_file", false); + obs_data_set_default_bool(s, "only_while_recording", false); + obs_data_set_default_bool(s, 
"rename_file_to_match_recording", true); + obs_data_set_default_int(s, "min_sub_duration", 1000); + obs_data_set_default_int(s, "max_sub_duration", 3000); + obs_data_set_default_bool(s, "advanced_settings", false); + obs_data_set_default_bool(s, "translate", false); + obs_data_set_default_string(s, "translate_target_language", "__es__"); + obs_data_set_default_int(s, "translate_add_context", 1); + obs_data_set_default_bool(s, "translate_only_full_sentences", true); + obs_data_set_default_string(s, "translate_model", "whisper-based-translation"); + obs_data_set_default_string(s, "translation_model_path_external", ""); + obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100); + obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); + obs_data_set_default_bool(s, "partial_group", false); + obs_data_set_default_int(s, "partial_latency", 1100); + + // translation options + obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); + obs_data_set_default_double(s, "translation_repetition_penalty", 2.0); + obs_data_set_default_int(s, "translation_beam_size", 1); + obs_data_set_default_int(s, "translation_max_decoding_length", 65); + obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); + obs_data_set_default_int(s, "translation_max_input_length", 65); + + // Whisper parameters + obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); + obs_data_set_default_int(s, "n_context_sentences", 0); + obs_data_set_default_string(s, "initial_prompt", ""); + obs_data_set_default_int(s, "n_threads", 4); + obs_data_set_default_int(s, "n_max_text_ctx", 16384); + obs_data_set_default_bool(s, "whisper_translate", false); + obs_data_set_default_bool(s, "no_context", true); + obs_data_set_default_bool(s, "single_segment", true); + obs_data_set_default_bool(s, "print_special", false); + obs_data_set_default_bool(s, "print_progress", false); + obs_data_set_default_bool(s, "print_realtime", false); + obs_data_set_default_bool(s, "print_timestamps", false); + obs_data_set_default_bool(s, "token_timestamps", false); + obs_data_set_default_bool(s, "dtw_token_timestamps", false); + obs_data_set_default_double(s, "thold_pt", 0.01); + obs_data_set_default_double(s, "thold_ptsum", 0.01); + obs_data_set_default_int(s, "max_len", 0); + obs_data_set_default_bool(s, "split_on_word", true); + obs_data_set_default_int(s, "max_tokens", 0); + obs_data_set_default_bool(s, "suppress_blank", false); + obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); + obs_data_set_default_double(s, "temperature", 0.1); + obs_data_set_default_double(s, "max_initial_ts", 1.0); + obs_data_set_default_double(s, "length_penalty", -1.0); +} diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 3683c18..657fea6 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -174,7 +174,7 @@ void transcription_filter_update(void *data, obs_data_t *s) obs_log(gf->log_level, "LocalVocal filter update"); gf->log_level = (int)obs_data_get_int(s, "log_level"); - gf->vad_enabled = obs_data_get_bool(s, "vad_enabled"); + gf->vad_mode = (int)obs_data_get_int(s, "vad_mode"); gf->log_words = obs_data_get_bool(s, "log_words"); gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream"); gf->save_to_file = obs_data_get_bool(s, "file_output_enable"); @@ -187,7 +187,10 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->sentence_number = 1; gf->process_while_muted = obs_data_get_bool(s, 
"process_while_muted"); gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(s, "max_sub_duration"); gf->last_sub_render_time = now_ms(); + gf->duration_filter_threshold = (float)obs_data_get_double(s, "duration_filter_threshold"); + gf->segment_duration = (int)obs_data_get_int(s, "segment_duration"); gf->partial_transcription = obs_data_get_bool(s, "partial_group"); gf->partial_latency = (int)obs_data_get_int(s, "partial_latency"); bool new_buffered_output = obs_data_get_bool(s, "buffered_output"); @@ -281,9 +284,10 @@ void transcription_filter_update(void *data, obs_data_t *s) bool new_translate = obs_data_get_bool(s, "translate"); gf->target_lang = obs_data_get_string(s, "translate_target_language"); - gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context"); + gf->translation_ctx.add_context = (int)obs_data_get_int(s, "translate_add_context"); gf->translation_ctx.input_tokenization_style = (InputTokenizationStyle)obs_data_get_int(s, "translate_input_tokenization_style"); + gf->translate_only_full_sentences = obs_data_get_bool(s, "translate_only_full_sentences"); gf->translation_output = obs_data_get_string(s, "translate_output"); std::string new_translate_model_index = obs_data_get_string(s, "translate_model"); std::string new_translation_model_path_external = @@ -342,6 +346,8 @@ void transcription_filter_update(void *data, obs_data_t *s) { std::lock_guard lock(gf->whisper_ctx_mutex); + gf->n_context_sentences = (int)obs_data_get_int(s, "n_context_sentences"); + gf->sentence_psum_accept_thresh = (float)obs_data_get_double(s, "sentence_psum_accept_thresh"); @@ -390,7 +396,7 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts"); gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty"); - if (gf->vad_enabled && gf->vad) { + if (gf->vad) { const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold"); gf->vad->set_threshold(vad_threshold); } @@ -431,6 +437,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / MAX_MS_WORK_BUFFER)); gf->last_num_frames = 0; gf->min_sub_duration = (int)obs_data_get_int(settings, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(settings, "max_sub_duration"); gf->last_sub_render_time = now_ms(); gf->log_level = (int)obs_data_get_int(settings, "log_level"); gf->save_srt = obs_data_get_bool(settings, "subtitle_save_srt"); @@ -551,72 +558,3 @@ void transcription_filter_hide(void *data) static_cast(data); obs_log(gf->log_level, "filter hide"); } - -void transcription_filter_defaults(obs_data_t *s) -{ - obs_log(LOG_DEBUG, "filter defaults"); - - obs_data_set_default_bool(s, "buffered_output", false); - obs_data_set_default_int(s, "buffer_num_lines", 2); - obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); - obs_data_set_default_int(s, "buffer_output_type", - (int)TokenBufferSegmentation::SEGMENTATION_TOKEN); - - obs_data_set_default_bool(s, "vad_enabled", true); - obs_data_set_default_double(s, "vad_threshold", 0.65); - obs_data_set_default_int(s, "log_level", LOG_DEBUG); - obs_data_set_default_bool(s, "log_words", false); - obs_data_set_default_bool(s, "caption_to_stream", false); - obs_data_set_default_string(s, "whisper_model_path", "Whisper Tiny English (74Mb)"); - obs_data_set_default_string(s, "whisper_language_select", "en"); - 
obs_data_set_default_string(s, "subtitle_sources", "none"); - obs_data_set_default_bool(s, "process_while_muted", false); - obs_data_set_default_bool(s, "subtitle_save_srt", false); - obs_data_set_default_bool(s, "truncate_output_file", false); - obs_data_set_default_bool(s, "only_while_recording", false); - obs_data_set_default_bool(s, "rename_file_to_match_recording", true); - obs_data_set_default_int(s, "min_sub_duration", 3000); - obs_data_set_default_bool(s, "advanced_settings", false); - obs_data_set_default_bool(s, "translate", false); - obs_data_set_default_string(s, "translate_target_language", "__es__"); - obs_data_set_default_bool(s, "translate_add_context", true); - obs_data_set_default_string(s, "translate_model", "whisper-based-translation"); - obs_data_set_default_string(s, "translation_model_path_external", ""); - obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100); - obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); - obs_data_set_default_bool(s, "partial_group", false); - obs_data_set_default_int(s, "partial_latency", 1100); - - // translation options - obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); - obs_data_set_default_double(s, "translation_repetition_penalty", 2.0); - obs_data_set_default_int(s, "translation_beam_size", 1); - obs_data_set_default_int(s, "translation_max_decoding_length", 65); - obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); - obs_data_set_default_int(s, "translation_max_input_length", 65); - - // Whisper parameters - obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); - obs_data_set_default_string(s, "initial_prompt", ""); - obs_data_set_default_int(s, "n_threads", 4); - obs_data_set_default_int(s, "n_max_text_ctx", 16384); - obs_data_set_default_bool(s, "whisper_translate", false); - obs_data_set_default_bool(s, "no_context", true); - obs_data_set_default_bool(s, "single_segment", true); - obs_data_set_default_bool(s, "print_special", false); - obs_data_set_default_bool(s, "print_progress", false); - obs_data_set_default_bool(s, "print_realtime", false); - obs_data_set_default_bool(s, "print_timestamps", false); - obs_data_set_default_bool(s, "token_timestamps", false); - obs_data_set_default_bool(s, "dtw_token_timestamps", false); - obs_data_set_default_double(s, "thold_pt", 0.01); - obs_data_set_default_double(s, "thold_ptsum", 0.01); - obs_data_set_default_int(s, "max_len", 0); - obs_data_set_default_bool(s, "split_on_word", true); - obs_data_set_default_int(s, "max_tokens", 0); - obs_data_set_default_bool(s, "suppress_blank", false); - obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); - obs_data_set_default_double(s, "temperature", 0.1); - obs_data_set_default_double(s, "max_initial_ts", 1.0); - obs_data_set_default_double(s, "length_penalty", -1.0); -} diff --git a/src/translation/translation-language-utils.cpp b/src/translation/translation-language-utils.cpp new file mode 100644 index 0000000..685ca1a --- /dev/null +++ b/src/translation/translation-language-utils.cpp @@ -0,0 +1,33 @@ +#include "translation-language-utils.h" + +#include +#include + +std::string remove_start_punctuation(const std::string &text) +{ + if (text.empty()) { + return text; + } + + // Convert the input string to ICU's UnicodeString + icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(text); + + // Find the index of the first non-punctuation character + int32_t start = 0; + while (start < ustr.length()) { + 
+		UChar32 ch = ustr.char32At(start); + if (!u_ispunct(ch)) { + break; + } + start += U16_LENGTH(ch); + } + + // Create a new UnicodeString with punctuation removed from the start + icu::UnicodeString result = ustr.tempSubString(start); + + // Convert the result back to UTF-8 + std::string output; + result.toUTF8String(output); + + return output; +} diff --git a/src/translation/translation-language-utils.h b/src/translation/translation-language-utils.h new file mode 100644 index 0000000..44b450a --- /dev/null +++ b/src/translation/translation-language-utils.h @@ -0,0 +1,8 @@ +#ifndef TRANSLATION_LANGUAGE_UTILS_H +#define TRANSLATION_LANGUAGE_UTILS_H + +#include <string> + +std::string remove_start_punctuation(const std::string &text); + +#endif // TRANSLATION_LANGUAGE_UTILS_H \ No newline at end of file diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp index e11f072..0701d95 100644 --- a/src/translation/translation.cpp +++ b/src/translation/translation.cpp @@ -3,6 +3,7 @@ #include "model-utils/model-find-utils.h" #include "transcription-filter-data.h" #include "language_codes.h" +#include "translation-language-utils.h" #include #include @@ -114,31 +115,53 @@ int translate(struct translation_context &translation_ctx, const std::string &te if (translation_ctx.input_tokenization_style == INPUT_TOKENIZAION_M2M100) { // set input tokens std::vector<std::string> input_tokens = {source_lang, "<s>"}; - if (translation_ctx.add_context && + if (translation_ctx.add_context > 0 && translation_ctx.last_input_tokens.size() > 0) { - input_tokens.insert(input_tokens.end(), - translation_ctx.last_input_tokens.begin(), - translation_ctx.last_input_tokens.end()); + // add the last input tokens sentences to the input tokens + for (const auto &tokens : translation_ctx.last_input_tokens) { + input_tokens.insert(input_tokens.end(), tokens.begin(), + tokens.end()); + } } std::vector<std::string> new_input_tokens = translation_ctx.tokenizer(text); input_tokens.insert(input_tokens.end(), new_input_tokens.begin(), new_input_tokens.end()); input_tokens.push_back("</s>"); - translation_ctx.last_input_tokens = new_input_tokens; + // log the input tokens + std::string input_tokens_str; + for (const auto &token : input_tokens) { + input_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Input tokens: %s", input_tokens_str.c_str()); + + translation_ctx.last_input_tokens.push_back(new_input_tokens); + // remove the oldest input tokens + while (translation_ctx.last_input_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_input_tokens.pop_front(); + } const std::vector<std::vector<std::string>> batch = {input_tokens}; // get target prefix target_prefix = {target_lang}; - if (translation_ctx.add_context && + // add the last translation tokens to the target prefix + if (translation_ctx.add_context > 0 && translation_ctx.last_translation_tokens.size() > 0) { - target_prefix.insert( - target_prefix.end(), - translation_ctx.last_translation_tokens.begin(), - translation_ctx.last_translation_tokens.end()); + for (const auto &tokens : translation_ctx.last_translation_tokens) { + target_prefix.insert(target_prefix.end(), tokens.begin(), + tokens.end()); + } } + // log the target prefix + std::string target_prefix_str; + for (const auto &token : target_prefix) { + target_prefix_str += token + ","; + } + obs_log(LOG_INFO, "Target prefix: %s", target_prefix_str.c_str()); + const std::vector<std::vector<std::string>> target_prefix_batch = { target_prefix}; results = translation_ctx.translator->translate_batch( @@ -161,9 +184,26 @@ int translate(struct translation_context 
&translation_ctx, const std::string &te std::vector translation_tokens( tokens_result.begin() + target_prefix.size(), tokens_result.end()); - translation_ctx.last_translation_tokens = translation_tokens; + // log the translation tokens + std::string translation_tokens_str; + for (const auto &token : translation_tokens) { + translation_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Translation tokens: %s", translation_tokens_str.c_str()); + + // save the translation tokens + translation_ctx.last_translation_tokens.push_back(translation_tokens); + // remove the oldest translation tokens + while (translation_ctx.last_translation_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_translation_tokens.pop_front(); + } + obs_log(LOG_INFO, "Last translation tokens deque size: %d", + (int)translation_ctx.last_translation_tokens.size()); + // detokenize - result = translation_ctx.detokenizer(translation_tokens); + const std::string result_ = translation_ctx.detokenizer(translation_tokens); + result = remove_start_punctuation(result_); } catch (std::exception &e) { obs_log(LOG_ERROR, "Error: %s", e.what()); return OBS_POLYGLOT_TRANSLATION_FAIL; diff --git a/src/translation/translation.h b/src/translation/translation.h index 0d45080..c740726 100644 --- a/src/translation/translation.h +++ b/src/translation/translation.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -25,10 +26,10 @@ struct translation_context { std::unique_ptr options; std::function(const std::string &)> tokenizer; std::function &)> detokenizer; - std::vector last_input_tokens; - std::vector last_translation_tokens; - // Use the last translation as context for the next translation - bool add_context; + std::deque> last_input_tokens; + std::deque> last_translation_tokens; + // How many sentences to use as context for the next translation + int add_context; InputTokenizationStyle input_tokenization_style; }; diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index ac34534..3e3b002 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -6,6 +6,9 @@ #include "whisper-utils.h" #include "transcription-utils.h" +#include +#include + #include #ifdef _WIN32 @@ -75,37 +78,74 @@ void TokenBufferThread::log_token_vector(const std::vector &tokens) obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); } -void TokenBufferThread::addSentence(const std::string &sentence) +void TokenBufferThread::addSentenceFromStdString(const std::string &sentence, + TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial) { + if (sentence.empty()) { + return; + } #ifdef _WIN32 // on windows convert from multibyte to wide char int count = MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0); - std::wstring sentence_ws(count, 0); + TokenBufferString sentence_ws(count, 0); MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0], count); #else - std::string sentence_ws = sentence; + TokenBufferString sentence_ws = sentence; #endif - // split to characters - std::vector characters; - for (const auto &c : sentence_ws) { - characters.push_back(TokenBufferString(1, c)); + + TokenBufferSentence sentence_for_add; + sentence_for_add.start_time = start_time; + sentence_for_add.end_time = end_time; + + if (this->segmentation == SEGMENTATION_WORD) { + // split the sentence to words + std::vector words; + 
+		std::basic_istringstream<TokenBufferChar> iss(sentence_ws); + TokenBufferString word_token; + while (iss >> word_token) { + words.push_back(word_token); + } + // add the words to a sentence + for (const auto &word : words) { + sentence_for_add.tokens.push_back({word, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + } else if (this->segmentation == SEGMENTATION_TOKEN) { + // split to characters + std::vector<TokenBufferString> characters; + for (const auto &c : sentence_ws) { + characters.push_back(TokenBufferString(1, c)); + } + // add the characters to a sentence + for (const auto &character : characters) { + sentence_for_add.tokens.push_back({character, is_partial}); + } + } else { + // add the whole sentence as a single token + sentence_for_add.tokens.push_back({sentence_ws, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + addSentence(sentence_for_add); +} - std::lock_guard lock(inputQueueMutex); +void TokenBufferThread::addSentence(const TokenBufferSentence &sentence) +{ + std::lock_guard lock(this->inputQueueMutex); - // add the characters to the inputQueue - for (const auto &character : characters) { + // add the tokens to the inputQueue + for (const auto &character : sentence.tokens) { inputQueue.push_back(character); } - inputQueue.push_back(SPACE); + inputQueue.push_back({SPACE, sentence.tokens.back().is_partial}); // add to the contribution queue as well - for (const auto &character : characters) { + for (const auto &character : sentence.tokens) { contributionQueue.push_back(character); } - contributionQueue.push_back(SPACE); + contributionQueue.push_back({SPACE, sentence.tokens.back().is_partial}); this->lastContributionTime = std::chrono::steady_clock::now(); } @@ -148,7 +188,7 @@ void TokenBufferThread::monitor() if (this->segmentation == SEGMENTATION_TOKEN) { // pop tokens until a space is found while (!presentationQueue.empty() && - presentationQueue.front() != SPACE) { + presentationQueue.front().token != SPACE) { presentationQueue.pop_front(); } } @@ -158,6 +198,13 @@ std::lock_guard lock(inputQueueMutex); if (!inputQueue.empty()) { + // if the input on the inputQueue is partial - first remove all partials + // from the end of the presentation queue + while (!presentationQueue.empty() && + presentationQueue.back().is_partial) { + presentationQueue.pop_back(); + } + // if there are tokens on the input queue // then add to the presentation queue based on the segmentation if (this->segmentation == SEGMENTATION_SENTENCE) { @@ -171,16 +218,17 @@ presentationQueue.push_back(inputQueue.front()); inputQueue.pop_front(); } else { + // SEGMENTATION_WORD // skip spaces in the beginning of the input queue while (!inputQueue.empty() && - inputQueue.front() == SPACE) { + inputQueue.front().token == SPACE) { inputQueue.pop_front(); } // add one word to the presentation queue - TokenBufferString word; + TokenBufferToken word; while (!inputQueue.empty() && - inputQueue.front() != SPACE) { - word += inputQueue.front(); + inputQueue.front().token != SPACE) { + word = inputQueue.front(); inputQueue.pop_front(); } presentationQueue.push_back(word); @@ -200,7 +248,7 @@ size_t wordsInSentence = 0; for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &word = presentationQueue[i]; - sentences.back() += word + SPACE; + sentences.back() += word.token + SPACE; wordsInSentence++; if (wordsInSentence == this->numPerSentence) { sentences.push_back(TokenBufferString()); @@
-211,12 +259,12 @@ for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &token = presentationQueue[i]; // skip spaces in the beginning of a sentence (tokensInSentence == 0) - if (token == SPACE && + if (token.token == SPACE && sentences.back().length() == 0) { continue; } - sentences.back() += token; + sentences.back() += token.token; if (sentences.back().length() == this->numPerSentence) { // if the next character is not a space - this is a broken word @@ -280,7 +328,7 @@ void TokenBufferThread::monitor() // take the contribution queue and send it to the output TokenBufferString contribution; for (const auto &token : contributionQueue) { - contribution += token; + contribution += token.token; } contributionQueue.clear(); #ifdef _WIN32 diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 13be208..7666669 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -16,8 +16,10 @@ #ifdef _WIN32 typedef std::wstring TokenBufferString; +typedef wchar_t TokenBufferChar; #else typedef std::string TokenBufferString; +typedef char TokenBufferChar; #endif struct transcription_filter_data; @@ -27,6 +29,22 @@ enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST }; typedef std::chrono::time_point<std::chrono::steady_clock> TokenBufferTimePoint; +inline std::chrono::time_point<std::chrono::steady_clock> get_time_point_from_ms(uint64_t ms) +{ + return std::chrono::time_point<std::chrono::steady_clock>(std::chrono::milliseconds(ms)); +} + +struct TokenBufferToken { + TokenBufferString token; + bool is_partial; +}; + +struct TokenBufferSentence { + std::vector<TokenBufferToken> tokens; + TokenBufferTimePoint start_time; + TokenBufferTimePoint end_time; +}; + class TokenBufferThread { public: // default constructor @@ -40,7 +58,9 @@ class TokenBufferThread { std::chrono::seconds maxTime_, TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN); - void addSentence(const std::string &sentence); + void addSentenceFromStdString(const std::string &sentence, TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial = false); + void addSentence(const TokenBufferSentence &sentence); void clear(); void stopThread(); @@ -59,9 +79,9 @@ class TokenBufferThread { void log_token_vector(const std::vector<std::string> &tokens); int getWaitTime(TokenBufferSpeed speed) const; struct transcription_filter_data *gf; - std::deque<TokenBufferString> inputQueue; - std::deque<TokenBufferString> presentationQueue; - std::deque<TokenBufferString> contributionQueue; + std::deque<TokenBufferToken> inputQueue; + std::deque<TokenBufferToken> presentationQueue; + std::deque<TokenBufferToken> contributionQueue; std::thread workerThread; std::mutex inputQueueMutex; std::mutex presentationQueueMutex; diff --git a/src/whisper-utils/vad-processing.cpp b/src/whisper-utils/vad-processing.cpp new file mode 100644 index 0000000..0e9c744 --- /dev/null +++ b/src/whisper-utils/vad-processing.cpp @@ -0,0 +1,377 @@ + +#include <obs-module.h> + +#include "transcription-filter-data.h" + +#include "vad-processing.h" + +#ifdef _WIN32 +#define NOMINMAX +#include <Windows.h> +#endif + +int get_data_from_buf_and_resample(transcription_filter_data *gf, + uint64_t &start_timestamp_offset_ns, + uint64_t &end_timestamp_offset_ns) +{ + uint32_t num_frames_from_infos = 0; + + { + // scoped lock the buffer mutex + std::lock_guard lock(gf->whisper_buf_mutex); + + if (gf->input_buffers[0].size == 0) { + return 1; + } + + obs_log(gf->log_level, + "segmentation: currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); + + // max number of frames is 10 seconds worth of audio + const size_t max_num_frames = 
gf->sample_rate * 10; + + // pop all infos from the info buffer and mark the beginning timestamp from the first + // info as the beginning timestamp of the segment + struct transcription_filter_audio_info info_from_buf = {0}; + const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); + while (gf->info_buffer.size >= size_of_audio_info) { + circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); + num_frames_from_infos += info_from_buf.frames; + if (start_timestamp_offset_ns == 0) { + start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; + } + // Check if we're within the needed segment length + if (num_frames_from_infos > max_num_frames) { + // too big, push the last info into the buffer's front where it was + num_frames_from_infos -= info_from_buf.frames; + circlebuf_push_front(&gf->info_buffer, &info_from_buf, + size_of_audio_info); + break; + } + } + // calculate the end timestamp from the info plus the number of frames in the packet + end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns + + info_from_buf.frames * 1000000000 / gf->sample_rate; + + if (start_timestamp_offset_ns > end_timestamp_offset_ns) { + // this may happen when the incoming media has a timestamp reset + // in this case, we should figure out the start timestamp from the end timestamp + // and the number of frames + start_timestamp_offset_ns = + end_timestamp_offset_ns - + num_frames_from_infos * 1000000000 / gf->sample_rate; + } + + for (size_t c = 0; c < gf->channels; c++) { + // zero the rest of copy_buffers + memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); + } + + /* Pop from input circlebuf */ + for (size_t c = 0; c < gf->channels; c++) { + // Push the new data to copy_buffers[c] + circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], + num_frames_from_infos * sizeof(float)); + } + } + + obs_log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); + gf->last_num_frames = num_frames_from_infos; + + { + // resample to 16kHz + float *resampled_16khz[MAX_PREPROC_CHANNELS]; + uint32_t resampled_16khz_frames; + uint64_t ts_offset; + { + ProfileScope("resample"); + audio_resampler_resample(gf->resampler_to_whisper, + (uint8_t **)resampled_16khz, + &resampled_16khz_frames, &ts_offset, + (const uint8_t **)gf->copy_buffers, + (uint32_t)num_frames_from_infos); + } + + circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], + resampled_16khz_frames * sizeof(float)); + obs_log(gf->log_level, + "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, + gf->resampled_buffer.size); + } + + return 0; +} + +vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + const int ret = get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns); + if (ret != 0) { + return last_vad_state; + } + + const size_t vad_window_size_samples = gf->vad->get_window_size_samples() * sizeof(float); + const size_t min_vad_buffer_size = vad_window_size_samples * 8; + if (gf->resampled_buffer.size < min_vad_buffer_size) + return last_vad_state; + + size_t vad_num_windows = gf->resampled_buffer.size / vad_window_size_samples; + + std::vector<float> vad_input; + vad_input.resize(vad_num_windows * gf->vad->get_window_size_samples());
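+	// only whole VAD windows are consumed below; any remainder shorter than one + // window stays in resampled_buffer for the next call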
circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", + vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? "yes" : "no"); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, !last_vad_state.vad_on); + } + + const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; + const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + vad_state current_vad_state = {false, start_ts_offset_ms, end_ts_offset_ms, + last_vad_state.last_partial_segment_end_ts}; + + std::vector stamps = gf->vad->get_speech_timestamps(); + if (stamps.size() == 0) { + obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); + if (last_vad_state.vad_on) { + obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_WAS_ON); + current_vad_state.last_partial_segment_end_ts = 0; + } + + if (gf->enable_audio_chunks_callback) { + audio_chunk_callback(gf, vad_input.data(), vad_input.size(), + VAD_STATE_IS_OFF, + {DETECTION_RESULT_SILENCE, + "[silence]", + current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + {}}); + } + + return current_vad_state; + } + + // process vad segments + for (size_t i = 0; i < stamps.size(); i++) { + int start_frame = stamps[i].start; + if (i > 0) { + // if this is not the first segment, start from the end of the previous segment + start_frame = stamps[i - 1].end; + } else { + // take at least 100ms of audio before the first speech segment, if available + start_frame = std::max(0, start_frame - WHISPER_SAMPLE_RATE / 10); + } + + int end_frame = stamps[i].end; + // if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { + // // take at least 100ms of audio after the last speech segment, if available + // end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, + // (int)vad_input.size()); + // } + + const int number_of_frames = end_frame - start_frame; + + // push the data into gf-whisper_buffer + circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, + number_of_frames * sizeof(float)); + + obs_log(gf->log_level, + "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). current size: %lu bytes / %lu frames / %lu ms", + i, (stamps.size() - 1), start_frame, end_frame, number_of_frames, + number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, + gf->whisper_buffer.size / sizeof(float), + gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); + + // segment "end" is in the middle of the buffer, send it to inference + if (stamps[i].end < (int)vad_input.size()) { + // new "ending" segment (not up to the end of the buffer) + obs_log(gf->log_level, "VAD segment end -> send to inference"); + // find the end timestamp of the segment + const uint64_t segment_end_ts = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + run_inference_and_callbacks( + gf, last_vad_state.start_ts_offest_ms, segment_end_ts, + last_vad_state.vad_on ? 
VAD_STATE_WAS_ON : VAD_STATE_WAS_OFF); + current_vad_state.vad_on = false; + current_vad_state.start_ts_offest_ms = current_vad_state.end_ts_offset_ms; + current_vad_state.end_ts_offset_ms = 0; + current_vad_state.last_partial_segment_end_ts = 0; + last_vad_state = current_vad_state; + continue; + } + + // end not reached - speech is ongoing + current_vad_state.vad_on = true; + if (last_vad_state.vad_on) { + obs_log(gf->log_level, + "last vad state was: ON, start ts: %llu, end ts: %llu", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms); + current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; + } else { + obs_log(gf->log_level, + "last vad state was: OFF, start ts: %llu, end ts: %llu. start_ts_offset_ms: %llu, start_frame: %d", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, + start_ts_offset_ms, start_frame); + current_vad_state.start_ts_offest_ms = + start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE; + } + current_vad_state.end_ts_offset_ms = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + obs_log(gf->log_level, + "end not reached. vad state: ON, start ts: %llu, end ts: %llu", + current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); + + last_vad_state = current_vad_state; + + // if partial transcription is enabled, check if we should send a partial segment + if (!gf->partial_transcription) { + continue; + } + + // current length of audio in buffer + const uint64_t current_length_ms = + (current_vad_state.end_ts_offset_ms > 0 + ? current_vad_state.end_ts_offset_ms + : current_vad_state.start_ts_offest_ms) - + (current_vad_state.last_partial_segment_end_ts > 0 + ? current_vad_state.last_partial_segment_end_ts + : current_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + current_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + current_vad_state.last_partial_segment_end_ts = + current_vad_state.end_ts_offset_ms; + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } + } + + return current_vad_state; +} + +vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + if (get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns) != 0) { + return last_vad_state; + } + + last_vad_state.end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + // extract the data from the resampled buffer with circlebuf_pop_front into a temp buffer + // and then push it into the whisper buffer + const size_t resampled_buffer_size = gf->resampled_buffer.size; + std::vector<float> temp_buffer; + temp_buffer.resize(resampled_buffer_size); + circlebuf_pop_front(&gf->resampled_buffer, temp_buffer.data(), resampled_buffer_size); + circlebuf_push_back(&gf->whisper_buffer, temp_buffer.data(), resampled_buffer_size); + + obs_log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); + + // use last_vad_state timestamps to calculate the duration of the current segment + if (last_vad_state.end_ts_offset_ms - last_vad_state.start_ts_offest_ms >= + (uint64_t)gf->segment_duration) {
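+		// hybrid mode sends full segments to inference on duration alone, with no + // VAD gating; the VAD further below only gates partial segments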
obs_log(gf->log_level, "%d seconds worth of audio -> send to inference", + gf->segment_duration); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, VAD_STATE_WAS_ON); + last_vad_state.start_ts_offest_ms = end_timestamp_offset_ns / 1000000; + last_vad_state.last_partial_segment_end_ts = 0; + return last_vad_state; + } + + // if partial transcription is enabled, check if we should send a partial segment + if (gf->partial_transcription) { + // current length of audio in buffer + const uint64_t current_length_ms = + (last_vad_state.end_ts_offset_ms > 0 ? last_vad_state.end_ts_offset_ms + : last_vad_state.start_ts_offest_ms) - + (last_vad_state.last_partial_segment_end_ts > 0 + ? last_vad_state.last_partial_segment_end_ts + : last_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + last_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + last_vad_state.last_partial_segment_end_ts = + last_vad_state.end_ts_offset_ms; + + // run vad on the current buffer + std::vector vad_input; + vad_input.resize(gf->whisper_buffer.size / sizeof(float)); + circlebuf_peek_front(&gf->whisper_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %.1f ms", + vad_input.size(), + (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, true); + } + + if (gf->vad->get_speech_timestamps().size() > 0) { + // VAD detected speech in the partial segment + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } else { + // VAD detected silence in the partial segment + obs_log(gf->log_level, "VAD detected silence in partial segment"); + // pop the partial segment from the whisper buffer, save some audio for the next segment + const size_t num_bytes_to_keep = + (WHISPER_SAMPLE_RATE / 4) * sizeof(float); + circlebuf_pop_front(&gf->whisper_buffer, nullptr, + gf->whisper_buffer.size - num_bytes_to_keep); + } + } + } + + return last_vad_state; +} + +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file) +{ + // initialize Silero VAD +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, + strlen(silero_vad_model_file), NULL, 0); + std::wstring silero_vad_model_path(count, 0); + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), + &silero_vad_model_path[0], count); + obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); +#else + std::string silero_vad_model_path = silero_vad_model_file; + obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); +#endif + // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py + // for silero vad parameters + gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 32, 0.5f, 100, + 100, 100)); +} diff --git a/src/whisper-utils/vad-processing.h b/src/whisper-utils/vad-processing.h new file mode 100644 index 0000000..996002b --- /dev/null +++ b/src/whisper-utils/vad-processing.h @@ -0,0 +1,18 @@ +#ifndef VAD_PROCESSING_H +#define VAD_PROCESSING_H + +enum VadState { 
diff --git a/src/whisper-utils/vad-processing.h b/src/whisper-utils/vad-processing.h
new file mode 100644
index 0000000..996002b
--- /dev/null
+++ b/src/whisper-utils/vad-processing.h
@@ -0,0 +1,18 @@
+#ifndef VAD_PROCESSING_H
+#define VAD_PROCESSING_H
+
+enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF, VAD_STATE_PARTIAL };
+enum VadMode { VAD_MODE_ACTIVE = 0, VAD_MODE_HYBRID, VAD_MODE_DISABLED };
+
+struct vad_state {
+	bool vad_on;
+	uint64_t start_ts_offest_ms;
+	uint64_t end_ts_offset_ms;
+	uint64_t last_partial_segment_end_ts;
+};
+
+vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state);
+vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state);
+void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file);
+
+#endif // VAD_PROCESSING_H
diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index 6d2d76e..6da91d9 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -17,18 +17,12 @@
 #endif
 
 #include "model-utils/model-find-utils.h"
+#include "vad-processing.h"
 
 #include
 #include
 #include
 
-struct vad_state {
-	bool vad_on;
-	uint64_t start_ts_offest_ms;
-	uint64_t end_ts_offset_ms;
-	uint64_t last_partial_segment_end_ts;
-};
-
 struct whisper_context *init_whisper_context(const std::string &model_path_in,
 					     struct transcription_filter_data *gf)
 {
@@ -161,6 +155,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
 	float *pcm32f_data = (float *)pcm32f_data_;
 	size_t pcm32f_size = pcm32f_num_samples;
 
+	// incoming duration in ms
+	const uint64_t incoming_duration_ms =
+		(uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE);
+
 	if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) {
 		obs_log(gf->log_level,
 			"Speech segment is less than 1 second, padding with zeros to 1 second");
@@ -175,7 +173,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
 	}
 
 	// duration in ms
-	const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
+	const uint64_t whisper_duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
 
 	std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
 	if (gf->whisper_context == nullptr) {
@@ -183,9 +181,22 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
 		return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
 	}
 
+	// set the initial prompt to the last transcription sentences (concatenated).
+	// initial_prompt must outlive the whisper_full() call below, since
+	// whisper_params.initial_prompt keeps only a raw pointer into it
+	std::string initial_prompt;
+	if (gf->n_context_sentences > 0 && !gf->last_transcription_sentence.empty()) {
+		initial_prompt = gf->last_transcription_sentence[0];
+		for (size_t i = 1; i < gf->last_transcription_sentence.size(); ++i) {
+			initial_prompt += " " + gf->last_transcription_sentence[i];
+		}
+		gf->whisper_params.initial_prompt = initial_prompt.c_str();
+		obs_log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt);
+	}
+
 	// run the inference
 	int whisper_full_result = -1;
-	gf->whisper_params.duration_ms = (int)(duration_ms);
+	gf->whisper_params.duration_ms = (int)(whisper_duration_ms);
 	try {
 		whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params,
 						   pcm32f_data, (int)pcm32f_size);
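The new initial-prompt block conditions Whisper on the filter's own recent output. `gf->last_transcription_sentence` is maintained elsewhere in the filter (outside this diff); assuming it behaves like a bounded `std::vector<std::string>`, the rolling-context idea looks like this sketch:

```cpp
#include <string>
#include <vector>

// Hypothetical rolling context window behind n_context_sentences.
struct context_window {
	size_t max_sentences = 3;
	std::vector<std::string> sentences;

	void push(const std::string &s)
	{
		sentences.push_back(s);
		if (sentences.size() > max_sentences)
			sentences.erase(sentences.begin()); // drop the oldest sentence
	}

	// Space-join, matching the initial_prompt concatenation above.
	std::string prompt() const
	{
		std::string out;
		for (const std::string &s : sentences)
			out += (out.empty() ? "" : " ") + s;
		return out;
	}
};
```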
@@ -243,13 +251,13 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
 			// token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json
 			if (token.id > 50365 && token.id <= 51865) {
 				const float time = ((float)token.id - 50365.0f) * 0.02f;
-				const float duration_s = (float)duration_ms / 1000.0f;
-				const float ratio =
-					std::max(time, duration_s) / std::min(time, duration_s);
+				const float duration_s = (float)incoming_duration_ms / 1000.0f;
+				const float ratio = time / duration_s;
 				obs_log(gf->log_level,
-					"Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.",
-					token.id, time, duration_s, ratio);
-				if (ratio > 3.0f) {
+					"Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f. Threshold %.2f",
+					token.id, time, duration_s, ratio,
+					gf->duration_filter_threshold);
+				if (ratio > gf->duration_filter_threshold) {
 					// ratio is too high, skip this detection
 					obs_log(gf->log_level,
 						"Time token ratio too high, skipping");
@@ -263,8 +271,8 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf,
 					text += token_str;
 					tokens.push_back(token);
 				}
-				obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]",
-					n_segment, j, token.id, token_str, token.p, keep);
+				obs_log(gf->log_level, "S %d, T %d: %d\t%s\tp: %.3f [keep: %d]", n_segment,
+					j, token.id, token_str, token.p, keep);
 			}
 		}
 		sentence_p /= (float)tokens.size();
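The reworked duration filter compares a decoded timestamp token against the duration of the audio actually fed in, with a configurable threshold instead of the old hard-coded symmetric ratio of 3.0. Whisper's timestamp tokens start at id 50365 and step in 20 ms increments, so a worked example (illustrative values; the `duration_filter_threshold` default is assumed):

```cpp
#include <cstdio>

int main()
{
	const int token_id = 50865;                      // hypothetical timestamp token
	const float time_s = (token_id - 50365) * 0.02f; // decodes to 10.0 s
	const float incoming_s = 2.0f;                   // actual incoming audio
	const float ratio = time_s / incoming_s;         // 5.0
	const float threshold = 3.0f;                    // assumed default threshold

	// 5.0 > 3.0 -> the token claims far more time than was supplied: skip it.
	std::printf("ratio %.1f vs threshold %.1f -> %s\n", ratio, threshold,
		    ratio > threshold ? "skip" : "keep");
	return 0;
}
```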
%lu in overlap", - num_frames_from_infos, overlap_size); - gf->last_num_frames = num_frames_from_infos + overlap_size; - - { - // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - { - ProfileScope("resample"); - audio_resampler_resample(gf->resampler_to_whisper, - (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, - (uint32_t)num_frames_from_infos); - } - - obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", - (int)gf->channels, (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f); - circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], - resampled_16khz_frames * sizeof(float)); - } - - if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float))) - return last_vad_state; - - size_t len = - gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float)); - - std::vector vad_input; - vad_input.resize(len * gf->vad->get_window_size_samples()); - circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), - vad_input.size() * sizeof(float)); - - obs_log(gf->log_level, "sending %d frames to vad", vad_input.size()); - { - ProfileScope("vad->process"); - gf->vad->process(vad_input, !last_vad_state.vad_on); - } - - const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; - const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; - - vad_state current_vad_state = {false, start_ts_offset_ms, end_ts_offset_ms, - last_vad_state.last_partial_segment_end_ts}; - - std::vector stamps = gf->vad->get_speech_timestamps(); - if (stamps.size() == 0) { - obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); - if (last_vad_state.vad_on) { - obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); - run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, - last_vad_state.end_ts_offset_ms, - VAD_STATE_WAS_ON); - current_vad_state.last_partial_segment_end_ts = 0; - } - - if (gf->enable_audio_chunks_callback) { - audio_chunk_callback(gf, vad_input.data(), vad_input.size(), - VAD_STATE_IS_OFF, - {DETECTION_RESULT_SILENCE, - "[silence]", - current_vad_state.start_ts_offest_ms, - current_vad_state.end_ts_offset_ms, - {}}); - } - - return current_vad_state; - } - - // process vad segments - for (size_t i = 0; i < stamps.size(); i++) { - int start_frame = stamps[i].start; - if (i > 0) { - // if this is not the first segment, start from the end of the previous segment - start_frame = stamps[i - 1].end; - } else { - // take at least 100ms of audio before the first speech segment, if available - start_frame = std::max(0, start_frame - WHISPER_SAMPLE_RATE / 10); - } - - int end_frame = stamps[i].end; - if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { - // take at least 100ms of audio after the last speech segment, if available - end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, - (int)vad_input.size()); - } - - const int number_of_frames = end_frame - start_frame; - - // push the data into gf-whisper_buffer - circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, - number_of_frames * sizeof(float)); - - obs_log(gf->log_level, - "VAD segment %d. pushed %d to %d (%d frames / %lu ms). 
-			i, start_frame, end_frame, number_of_frames,
-			number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size,
-			gf->whisper_buffer.size / sizeof(float),
-			gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE);
-
-		// segment "end" is in the middle of the buffer, send it to inference
-		if (stamps[i].end < (int)vad_input.size()) {
-			// new "ending" segment (not up to the end of the buffer)
-			obs_log(gf->log_level, "VAD segment end -> send to inference");
-			// find the end timestamp of the segment
-			const uint64_t segment_end_ts =
-				start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE;
-			run_inference_and_callbacks(
-				gf, last_vad_state.start_ts_offest_ms, segment_end_ts,
-				last_vad_state.vad_on ? VAD_STATE_WAS_ON : VAD_STATE_WAS_OFF);
-			current_vad_state.vad_on = false;
-			current_vad_state.start_ts_offest_ms = current_vad_state.end_ts_offset_ms;
-			current_vad_state.end_ts_offset_ms = 0;
-			current_vad_state.last_partial_segment_end_ts = 0;
-			last_vad_state = current_vad_state;
-			continue;
-		}
-
-		// end not reached - speech is ongoing
-		current_vad_state.vad_on = true;
-		if (last_vad_state.vad_on) {
-			current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms;
-		} else {
-			current_vad_state.start_ts_offest_ms =
-				start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE;
-		}
-		obs_log(gf->log_level, "end not reached. vad state: start ts: %llu, end ts: %llu",
-			current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms);
-
-		last_vad_state = current_vad_state;
-
-		// if partial transcription is enabled, check if we should send a partial segment
-		if (!gf->partial_transcription) {
-			continue;
-		}
-
-		// current length of audio in buffer
-		const uint64_t current_length_ms =
-			(current_vad_state.end_ts_offset_ms > 0
-				 ? current_vad_state.end_ts_offset_ms
-				 : current_vad_state.start_ts_offest_ms) -
-			(current_vad_state.last_partial_segment_end_ts > 0
-				 ? current_vad_state.last_partial_segment_end_ts
-				 : current_vad_state.start_ts_offest_ms);
-		obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms",
-			current_vad_state.last_partial_segment_end_ts, current_length_ms);
-
-		if (current_length_ms > (uint64_t)gf->partial_latency) {
-			current_vad_state.last_partial_segment_end_ts =
-				current_vad_state.end_ts_offset_ms;
-			// send partial segment to inference
-			obs_log(gf->log_level, "Partial segment -> send to inference");
-			run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms,
-						    current_vad_state.end_ts_offset_ms,
-						    VAD_STATE_PARTIAL);
-		}
-	}
-
-	return current_vad_state;
-}
-
 void whisper_loop(void *data)
 {
 	if (data == nullptr) {
@@ -566,7 +347,7 @@ void whisper_loop(void *data)
 
 	obs_log(gf->log_level, "Starting whisper thread");
 
-	vad_state current_vad_state = {false, 0, 0, 0};
+	vad_state current_vad_state = {false, now_ms(), 0, 0};
 
 	const char *whisper_loop_name = "Whisper loop";
 	profile_register_root(whisper_loop_name, 50 * 1000 * 1000);
@@ -584,12 +365,16 @@ void whisper_loop(void *data)
 			}
 		}
 
-		current_vad_state = vad_based_segmentation(gf, current_vad_state);
+		if (gf->vad_mode == VAD_MODE_HYBRID) {
+			current_vad_state = hybrid_vad_segmentation(gf, current_vad_state);
+		} else if (gf->vad_mode == VAD_MODE_ACTIVE) {
+			current_vad_state = vad_based_segmentation(gf, current_vad_state);
+		}
 
 		if (!gf->cleared_last_sub) {
 			// check if we should clear the current sub depending on the minimum subtitle duration
 			uint64_t now = now_ms();
-			if ((now - gf->last_sub_render_time) > gf->min_sub_duration) {
+			if ((now - gf->last_sub_render_time) > gf->max_sub_duration) {
 				// clear the current sub, call the callback with an empty string
 				obs_log(gf->log_level,
 					"Clearing current subtitle. now: %lu ms, last: %lu ms", now,
					gf->last_sub_render_time);
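`whisper_loop` now dispatches on `gf->vad_mode`; note that with `VAD_MODE_DISABLED` neither segmentation routine runs. The mapping from the filter's setting to the enum lives in transcription-filter.cpp, outside this diff — a hypothetical sketch of that mapping:

```cpp
enum VadMode { VAD_MODE_ACTIVE = 0, VAD_MODE_HYBRID, VAD_MODE_DISABLED };

// Hypothetical settings-to-enum mapping (property name and values assumed).
static VadMode vad_mode_from_setting(long long value)
{
	switch (value) {
	case 1:
		return VAD_MODE_HYBRID;   // fixed-duration segments, VAD gates partials
	case 2:
		return VAD_MODE_DISABLED; // no segmentation branch runs
	default:
		return VAD_MODE_ACTIVE;   // full Silero-VAD-driven segmentation
	}
}
```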
diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h
index 5bc162b..a00f7cb 100644
--- a/src/whisper-utils/whisper-processing.h
+++ b/src/whisper-utils/whisper-processing.h
@@ -29,10 +29,10 @@ struct DetectionResultWithText {
 	std::string language;
 };
 
-enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF, VAD_STATE_PARTIAL };
-
 void whisper_loop(void *data);
 struct whisper_context *init_whisper_context(const std::string &model_path,
 					     struct transcription_filter_data *gf);
+void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms,
+				 uint64_t end_offset_ms, int vad_state);
 
 #endif // WHISPER_PROCESSING_H
diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp
index c2e4929..84f3b0a 100644
--- a/src/whisper-utils/whisper-utils.cpp
+++ b/src/whisper-utils/whisper-utils.cpp
@@ -2,13 +2,10 @@
 #include "plugin-support.h"
 #include "model-utils/model-downloader.h"
 #include "whisper-processing.h"
+#include "vad-processing.h"
 
 #include
 
-#ifdef _WIN32
-#include <windows.h>
-#endif
-
 void shutdown_whisper_thread(struct transcription_filter_data *gf)
 {
 	obs_log(gf->log_level, "shutdown_whisper_thread");
@@ -40,21 +37,7 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,
 	}
 
 	// initialize Silero VAD
-#ifdef _WIN32
-	// convert mbstring to wstring
-	int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file,
-					strlen(silero_vad_model_file), NULL, 0);
-	std::wstring silero_vad_model_path(count, 0);
-	MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file),
-			    &silero_vad_model_path[0], count);
-	obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str());
-#else
-	std::string silero_vad_model_path = silero_vad_model_file;
-	obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str());
-#endif
-	// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
-	// for silero vad parameters
-	gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));
+	initialize_vad(gf, silero_vad_model_file);
 
 	obs_log(gf->log_level, "Create whisper context");
 	gf->whisper_context = init_whisper_context(whisper_model_path, gf);