diff --git a/.github/workflows/build-project.yaml b/.github/workflows/build-project.yaml index addbd69..bbb98ff 100644 --- a/.github/workflows/build-project.yaml +++ b/.github/workflows/build-project.yaml @@ -271,6 +271,19 @@ jobs: "pluginName=${ProductName}" >> $env:GITHUB_OUTPUT "pluginVersion=${ProductVersion}" >> $env:GITHUB_OUTPUT + - name: Install vcpkg + id: vcpkg + run: | + git clone https://github.com/microsoft/vcpkg.git + cd vcpkg + .\bootstrap-vcpkg.bat + # Configure the VCPKG_ROOT and PATH environment variables + Add-Content $env:GITHUB_ENV "VCPKG_ROOT=${{ github.workspace }}/vcpkg" + Add-Content $env:GITHUB_PATH "${{ github.workspace }}/vcpkg" + # Install the necessary packages + cd .. + vcpkg install + - name: Build Plugin 🧱 uses: ./.github/actions/build-plugin with: diff --git a/.gitignore b/.gitignore index a064dbf..54b9297 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,8 @@ !README.md !/vendor !patch_libobs.diff +!vcpkg.json +!vcpkg-configuration.json # Exclude lock files *.lock.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 108ff06..c9cfc60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,8 @@ include(cmake/BuildCTranslate2.cmake) include(cmake/BuildSentencepiece.cmake) target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ct2 sentencepiece) +include(cmake/FindOpenSSL.cmake) + set(USE_SYSTEM_ONNXRUNTIME OFF CACHE STRING "Use system ONNX Runtime") @@ -100,6 +102,7 @@ target_sources( src/whisper-utils/whisper-model-utils.cpp src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp + src/timed-metadata/timed-metadata-utils.cpp src/translation/language_codes.cpp src/translation/translation.cpp src/translation/translation-utils.cpp diff --git a/buildspec.json b/buildspec.json index 168715e..caf2e23 100644 --- a/buildspec.json +++ b/buildspec.json @@ -38,7 +38,7 @@ }, "name": "obs-localvocal", "displayName": "OBS Localvocal", - "version": "0.3.2", + "version": "0.3.3", "author": "Roy Shilkrot", "website": "https://github.com/occ-ai/obs-localvocal", "email": "roy.shil@gmail.com", diff --git a/cmake/FindOpenSSL.cmake b/cmake/FindOpenSSL.cmake new file mode 100644 index 0000000..d524b7a --- /dev/null +++ b/cmake/FindOpenSSL.cmake @@ -0,0 +1,30 @@ +if(WIN32) + set(OPENSSL_ROOT_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg_installed/x64-windows" + CACHE STRING "Path to OpenSSL") +elseif(APPLE) + include(FetchContent) + + FetchContent_Declare( + openssl-macos-fetch + URL "https://github.com/occ-ai/occ-ai-dep-openssl/releases/download/0.0.1/openssl-3.3.1-macos.tar.gz" + URL_HASH SHA256=d578921b7168e21451f0b6e4ac4cb989c17abc6829c8c43f32136c1c2544ffde) + + FetchContent_MakeAvailable(openssl-macos-fetch) + + set(OPENSSL_ROOT_DIR + "${openssl-macos-fetch_SOURCE_DIR}" + CACHE STRING "Path to OpenSSL") +endif() + +find_package(OpenSSL REQUIRED) +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE OpenSSL::SSL) +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE OpenSSL::Crypto) + +# copy the openssl dlls to the release directory +if(WIN32) + set(OpenSS_LIB_NAMES "libcrypto-3-x64" "libssl-3-x64") + foreach(lib_name IN LISTS OpenSS_LIB_NAMES) + install(FILES ${OPENSSL_ROOT_DIR}/bin/${lib_name}.dll DESTINATION "obs-plugins/64bit") + endforeach() +endif() diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index e08b4af..0a2e956 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -72,6 +72,11 @@ buffer_num_lines="Number of lines" buffer_num_chars_per_line="Amount per line" buffer_output_type="Output type" open_filter_ui="Setup Filter and Replace" +amazon_ivs_parameters="Amazon IVS Integration" +amazon_ivs_channel_arn="Amazon IVS Channel ARN" +aws_access_key="AWS Access Key" +aws_secret_key="AWS Secret Key" +aws_region="AWS Region" advanced_settings_mode="Mode" simple_mode="Simple" advanced_mode="Advanced" diff --git a/src/model-utils/model-find-utils.cpp b/src/model-utils/model-find-utils.cpp index d2bb48f..dc680e7 100644 --- a/src/model-utils/model-find-utils.cpp +++ b/src/model-utils/model-find-utils.cpp @@ -24,11 +24,28 @@ std::string find_file_in_folder_by_name(const std::string &folder_path, std::string find_file_in_folder_by_regex_expression(const std::string &folder_path, const std::string &file_name_regex) { - for (const auto &entry : std::filesystem::directory_iterator(folder_path)) { - if (std::regex_match(entry.path().filename().string(), - std::regex(file_name_regex))) { - return entry.path().string(); + if (!std::filesystem::exists(folder_path)) { + obs_log(LOG_ERROR, "Folder does not exist: %s", folder_path.c_str()); + return ""; + } + if (!std::filesystem::is_directory(folder_path)) { + obs_log(LOG_ERROR, "Path is not a folder: %s", folder_path.c_str()); + return ""; + } + if (file_name_regex.empty()) { + obs_log(LOG_ERROR, "Empty file name regex"); + return ""; + } + try { + for (const auto &entry : std::filesystem::directory_iterator(folder_path)) { + if (std::regex_match(entry.path().filename().string(), + std::regex(file_name_regex))) { + return entry.path().string(); + } } + } catch (const std::exception &e) { + obs_log(LOG_ERROR, "Error finding file in folder by regex expression: %s", + e.what()); } return ""; } diff --git a/src/timed-metadata/timed-metadata-utils.cpp b/src/timed-metadata/timed-metadata-utils.cpp new file mode 100644 index 0000000..6ff7f97 --- /dev/null +++ b/src/timed-metadata/timed-metadata-utils.cpp @@ -0,0 +1,272 @@ + +#include "plugin-support.h" +#include "timed-metadata-utils.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +// HMAC SHA-256 function +std::string hmacSha256(const std::string &key, const std::string &data, bool isHexKey = false) +{ + unsigned char *digest; + size_t len = EVP_MAX_MD_SIZE; + digest = (unsigned char *)bzalloc(len); + + EVP_PKEY *pkey = nullptr; + if (isHexKey) { + // Convert hex string to binary data + std::vector hexKey; + for (size_t i = 0; i < key.length(); i += 2) { + std::string byteString = key.substr(i, 2); + unsigned char byte = (unsigned char)strtol(byteString.c_str(), NULL, 16); + hexKey.push_back(byte); + } + pkey = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, hexKey.data(), (int)hexKey.size()); + } else { + pkey = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, (unsigned char *)key.c_str(), + (int)key.length()); + } + + EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + EVP_DigestSignInit(ctx, NULL, EVP_sha256(), NULL, pkey); + EVP_DigestSignUpdate(ctx, data.c_str(), data.length()); + EVP_DigestSignFinal(ctx, digest, &len); + + EVP_PKEY_free(pkey); + EVP_MD_CTX_free(ctx); + + std::stringstream ss; + for (size_t i = 0; i < len; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') << (int)digest[i]; + } + bfree(digest); + return ss.str(); +} + +std::string sha256(const std::string &data) +{ + unsigned char hash[EVP_MAX_MD_SIZE]; + unsigned int lengthOfHash = 0; + + EVP_MD_CTX *context = EVP_MD_CTX_new(); + + if (context != nullptr) { + if (EVP_DigestInit_ex(context, EVP_sha256(), nullptr)) { + if (EVP_DigestUpdate(context, data.c_str(), data.length())) { + if (EVP_DigestFinal_ex(context, hash, &lengthOfHash)) { + EVP_MD_CTX_free(context); + + std::stringstream ss; + for (unsigned int i = 0; i < lengthOfHash; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << (int)hash[i]; + } + return ss.str(); + } + } + } + EVP_MD_CTX_free(context); + } + + return ""; +} + +std::string getCurrentTimestamp() +{ + auto now = std::chrono::system_clock::now(); + auto in_time_t = std::chrono::system_clock::to_time_t(now); + std::stringstream ss; + ss << std::put_time(std::gmtime(&in_time_t), "%Y%m%dT%H%M%SZ"); + return ss.str(); +} + +std::string getCurrentDate() +{ + auto now = std::chrono::system_clock::now(); + auto in_time_t = std::chrono::system_clock::to_time_t(now); + std::stringstream ss; + ss << std::put_time(std::gmtime(&in_time_t), "%Y%m%d"); + return ss.str(); +} + +size_t WriteCallback(void *ptr, size_t size, size_t nmemb, std::string *data) +{ + data->append((char *)ptr, size * nmemb); + return size * nmemb; +} + +void send_timed_metadata_to_ivs_endpoint(struct transcription_filter_data *gf, + Translation_Mode mode, const std::string &source_text, + const std::string &source_lang, + const std::string &target_text, + const std::string &target_lang) +{ + // below 4 should be from a configuration + std::string AWS_ACCESS_KEY = gf->aws_access_key; + std::string AWS_SECRET_KEY = gf->aws_secret_key; + std::string CHANNEL_ARN = gf->ivs_channel_arn; + std::string REGION = gf->aws_region; + + std::string SERVICE = "ivs"; + std::string HOST = "ivs." + REGION + ".amazonaws.com"; + + // Construct the inner JSON string + nlohmann::json inner_meta_data; + if (mode == NON_WHISPER_TRANSLATE) { + obs_log(gf->log_level, + "send_timed_metadata_to_ivs_endpoint - NON_WHISPER_TRANSLATE"); + nlohmann::json array; + if (!source_text.empty()) { + array.push_back({{"language", source_lang}, {"text", source_text}}); + } + if (!target_text.empty()) { + array.push_back({{"language", target_lang}, {"text", target_text}}); + } + if (array.empty()) { + obs_log(gf->log_level, + "send_timed_metadata_to_ivs_endpoint - source and target text empty"); + return; + } + inner_meta_data = {{"captions", array}}; + } else if (mode == WHISPER_TRANSLATE) { + if (target_text.empty()) { + obs_log(gf->log_level, + "send_timed_metadata_to_ivs_endpoint - target text empty"); + return; + } + obs_log(gf->log_level, "send_timed_metadata_to_ivs_endpoint - WHISPER_TRANSLATE"); + inner_meta_data = { + {"captions", {{{"language", target_lang}, {"text", target_text}}}}}; + } else { + if (source_text.empty()) { + obs_log(gf->log_level, + "send_timed_metadata_to_ivs_endpoint - source text empty"); + return; + } + obs_log(gf->log_level, "send_timed_metadata_to_ivs_endpoint - transcription mode"); + inner_meta_data = { + {"captions", {{{"language", source_lang}, {"text", source_text}}}}}; + } + + // Construct the outer JSON string + nlohmann::json inner_meta_data_as_string = inner_meta_data.dump(); + std::string METADATA = R"({ + "channelArn": ")" + CHANNEL_ARN + + R"(", + "metadata": )" + inner_meta_data_as_string.dump() + + R"( + })"; + + std::string DATE = getCurrentDate(); + std::string TIMESTAMP = getCurrentTimestamp(); + std::string PAYLOAD_HASH = sha256(METADATA); + + std::ostringstream canonicalRequest; + canonicalRequest << "POST\n" + << "/PutMetadata\n" + << "\n" + << "content-type:application/json\n" + << "host:" << HOST << "\n" + << "x-amz-date:" << TIMESTAMP << "\n" + << "\n" + << "content-type;host;x-amz-date\n" + << PAYLOAD_HASH; + std::string CANONICAL_REQUEST = canonicalRequest.str(); + std::string HASHED_CANONICAL_REQUEST = sha256(CANONICAL_REQUEST); + + std::string ALGORITHM = "AWS4-HMAC-SHA256"; + std::string CREDENTIAL_SCOPE = DATE + "/" + REGION + "/" + SERVICE + "/aws4_request"; + std::ostringstream stringToSign; + stringToSign << ALGORITHM << "\n" + << TIMESTAMP << "\n" + << CREDENTIAL_SCOPE << "\n" + << HASHED_CANONICAL_REQUEST; + std::string STRING_TO_SIGN = stringToSign.str(); + + std::string KEY = "AWS4" + AWS_SECRET_KEY; + std::string DATE_KEY = hmacSha256(KEY, DATE); + std::string REGION_KEY = hmacSha256(DATE_KEY, REGION, true); + std::string SERVICE_KEY = hmacSha256(REGION_KEY, SERVICE, true); + std::string SIGNING_KEY = hmacSha256(SERVICE_KEY, "aws4_request", true); + std::string SIGNATURE = hmacSha256(SIGNING_KEY, STRING_TO_SIGN, true); + + std::ostringstream authHeader; + authHeader << ALGORITHM << " Credential=" << AWS_ACCESS_KEY << "/" << CREDENTIAL_SCOPE + << ", SignedHeaders=content-type;host;x-amz-date, Signature=" << SIGNATURE; + + std::string AUTH_HEADER = authHeader.str(); + + // Initialize CURL and set options + CURL *curl; + CURLcode res; + curl = curl_easy_init(); + if (!curl) { + obs_log(LOG_ERROR, + "send_timed_metadata_to_ivs_endpoint failed: curl_easy_init failed"); + return; + } + + curl_easy_setopt(curl, CURLOPT_URL, ("https://" + HOST + "/PutMetadata").c_str()); + struct curl_slist *headers = NULL; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, ("Host: " + HOST).c_str()); + headers = curl_slist_append(headers, ("x-amz-date: " + TIMESTAMP).c_str()); + headers = curl_slist_append(headers, ("Authorization: " + AUTH_HEADER).c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, METADATA.c_str()); + + std::string response_string; + std::string header_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + curl_easy_setopt(curl, CURLOPT_HEADERDATA, &header_string); + + res = curl_easy_perform(curl); + if (res != CURLE_OK) { + obs_log(LOG_WARNING, "send_timed_metadata_to_ivs_endpoint failed:%s", + curl_easy_strerror(res)); + } else { + long response_code; + // Get the HTTP response code + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response_code); + obs_log(gf->log_level, "HTTP Status code: %ld", response_code); + if (response_code != 204) { + obs_log(LOG_WARNING, "HTTP response: %s", response_string.c_str()); + } + } + curl_slist_free_all(headers); + curl_easy_cleanup(curl); +} + +// source: transcription text, target: translation text +void send_timed_metadata_to_server(struct transcription_filter_data *gf, Translation_Mode mode, + const std::string &source_text, const std::string &source_lang, + const std::string &target_text, const std::string &target_lang) +{ + if (gf->aws_access_key.empty() || gf->aws_secret_key.empty() || + gf->ivs_channel_arn.empty() || gf->aws_region.empty()) { + obs_log(gf->log_level, + "send_timed_metadata_to_server failed: IVS settings not set"); + return; + } + + std::thread send_timed_metadata_thread([=]() { + send_timed_metadata_to_ivs_endpoint(gf, mode, source_text, source_lang, target_text, + target_lang); + }); + send_timed_metadata_thread.detach(); +} diff --git a/src/timed-metadata/timed-metadata-utils.h b/src/timed-metadata/timed-metadata-utils.h new file mode 100644 index 0000000..bca8c9e --- /dev/null +++ b/src/timed-metadata/timed-metadata-utils.h @@ -0,0 +1,15 @@ +#ifndef TIMED_METADATA_UTILS_H +#define TIMED_METADATA_UTILS_H + +#include +#include + +#include "transcription-filter-data.h" + +enum Translation_Mode { WHISPER_TRANSLATE, NON_WHISPER_TRANSLATE, TRANSCRIBE }; + +void send_timed_metadata_to_server(struct transcription_filter_data *gf, Translation_Mode mode, + const std::string &source_text, const std::string &source_lang, + const std::string &target_text, const std::string &target_lang); + +#endif // TIMED_METADATA_UTILS_H diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 3cdd586..79291cb 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -19,6 +19,7 @@ #include "translation/translation-includes.h" #include "whisper-utils/whisper-utils.h" #include "whisper-utils/whisper-model-utils.h" +#include "timed-metadata/timed-metadata-utils.h" #include "translation/language_codes.h" void send_caption_to_source(const std::string &target_source_name, const std::string &caption, @@ -189,6 +190,7 @@ void set_text_callback(struct transcription_filter_data *gf, { DetectionResultWithText result = resultIn; if (!result.text.empty() && result.result == DETECTION_RESULT_SPEECH) { + // this sub should be rendered - update the last sub render time gf->last_sub_render_time = now_ms(); gf->cleared_last_sub = false; } @@ -221,11 +223,22 @@ void set_text_callback(struct transcription_filter_data *gf, } } + // time the translation + uint64_t start_time = now_ms(); + // send the sentence to translation (if enabled) std::string translated_sentence = send_sentence_to_translation(str_copy, gf, result.language); if (gf->translate) { + // log the translation time + obs_log(gf->log_level, "Translation time: %llu ms", now_ms() - start_time); + + // send the translated sentence to the server + send_timed_metadata_to_server(gf, NON_WHISPER_TRANSLATE, str_copy, result.language, + translated_sentence, gf->target_lang); + + // send the translated sentence to the selected output if (gf->translation_output == "none") { // overwrite the original text with the translated text str_copy = translated_sentence; @@ -238,6 +251,8 @@ void set_text_callback(struct transcription_filter_data *gf, gf); } } + } else { + send_timed_metadata_to_server(gf, TRANSCRIBE, str_copy, result.language, "", ""); } if (gf->buffered_output) { diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 271e289..7c5164f 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -51,6 +51,7 @@ struct transcription_filter_data { /* whisper */ std::string whisper_model_path; + bool whisper_model_loaded_new; struct whisper_context *whisper_context; whisper_full_params whisper_params; @@ -92,7 +93,6 @@ struct transcription_filter_data { // Output file path to write the subtitles std::string output_file_path; std::string whisper_model_file_currently_loaded; - bool whisper_model_loaded_new; // Use std for thread and mutex std::thread whisper_thread; @@ -114,6 +114,13 @@ struct transcription_filter_data { TokenBufferSegmentation buffered_output_output_type = TokenBufferSegmentation::SEGMENTATION_TOKEN; + // Amazon IVS settings + bool ivs_enabled = false; + std::string ivs_channel_arn; + std::string aws_access_key; + std::string aws_secret_key; + std::string aws_region; + // ctor transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() { diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 6f3b87f..e877e94 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -46,8 +46,9 @@ bool advanced_settings_callback(obs_properties_t *props, obs_property_t *propert UNUSED_PARAMETER(property); // If advanced settings is enabled, show the advanced settings group const bool show_hide = obs_data_get_int(settings, "advanced_settings_mode") == 1; - for (const std::string &prop_name : {"whisper_params_group", "buffered_output_group", - "log_group", "advanced_group", "file_output_enable"}) { + for (const std::string &prop_name : + {"whisper_params_group", "buffered_output_group", "log_group", "advanced_group", + "file_output_enable", "amazon_ivs_group"}) { obs_property_set_visible(obs_properties_get(props, prop_name.c_str()), show_hide); } translation_options_callback(props, NULL, settings); @@ -457,6 +458,26 @@ void add_general_group_properties(obs_properties_t *ppts) } } +void add_amazon_ivs_group_properties(obs_properties_t *ppts) +{ + // add group for Amazon IVS settings + obs_properties_t *amazon_ivs_group = obs_properties_create(); + obs_properties_add_group(ppts, "amazon_ivs_group", MT_("amazon_ivs_parameters"), + OBS_GROUP_CHECKABLE, amazon_ivs_group); + // add Amazon IVS channel ARN + obs_properties_add_text(amazon_ivs_group, "amazon_ivs_channel_arn", + MT_("amazon_ivs_channel_arn"), OBS_TEXT_DEFAULT); + // add AWS_ACCESS_KEY + obs_properties_add_text(amazon_ivs_group, "aws_access_key", MT_("aws_access_key"), + OBS_TEXT_DEFAULT); + // add AWS_SECRET_KEY + obs_properties_add_text(amazon_ivs_group, "aws_secret_key", MT_("aws_secret_key"), + OBS_TEXT_PASSWORD); + // add region + obs_properties_add_text(amazon_ivs_group, "aws_region", MT_("aws_region"), + OBS_TEXT_DEFAULT); +} + obs_properties_t *transcription_filter_properties(void *data) { struct transcription_filter_data *gf = @@ -480,6 +501,7 @@ obs_properties_t *transcription_filter_properties(void *data) add_buffered_output_group_properties(ppts); add_advanced_group_properties(ppts, gf); add_logging_group_properties(ppts); + add_amazon_ivs_group_properties(ppts); add_whisper_params_group_properties(ppts); // Add a informative text about the plugin diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index b4d8ce9..682ee77 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -279,6 +279,14 @@ void transcription_filter_update(void *data, obs_data_t *s) } } + // Amazon IVS settings + gf->ivs_enabled = obs_data_get_bool(s, "amazon_ivs_group"); + gf->ivs_channel_arn = obs_data_get_string(s, "amazon_ivs_channel_arn"); + gf->aws_access_key = obs_data_get_string(s, "aws_access_key"); + gf->aws_secret_key = obs_data_get_string(s, "aws_secret_key"); + gf->aws_region = obs_data_get_string(s, "aws_region"); + + // translation settings bool new_translate = obs_data_get_bool(s, "translate"); gf->target_lang = obs_data_get_string(s, "translate_target_language"); gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context"); @@ -556,14 +564,14 @@ void transcription_filter_defaults(obs_data_t *s) { obs_log(LOG_DEBUG, "filter defaults"); - obs_data_set_default_bool(s, "buffered_output", false); + obs_data_set_default_bool(s, "buffered_output", true); obs_data_set_default_int(s, "buffer_num_lines", 2); - obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); + obs_data_set_default_int(s, "buffer_num_chars_per_line", 8); obs_data_set_default_int(s, "buffer_output_type", - (int)TokenBufferSegmentation::SEGMENTATION_TOKEN); + (int)TokenBufferSegmentation::SEGMENTATION_WORD); obs_data_set_default_bool(s, "vad_enabled", true); - obs_data_set_default_double(s, "vad_threshold", 0.65); + obs_data_set_default_double(s, "vad_threshold", 0.95); obs_data_set_default_int(s, "log_level", LOG_DEBUG); obs_data_set_default_bool(s, "log_words", false); obs_data_set_default_bool(s, "caption_to_stream", false); @@ -584,6 +592,8 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_string(s, "translation_model_path_external", ""); obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100); obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); + obs_data_set_default_string(s, "filter_words_replace", + serialize_filter_words_replace({{"MBC.*", ""}}).c_str()); // translation options obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); @@ -593,6 +603,13 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); obs_data_set_default_int(s, "translation_max_input_length", 65); + // Amazon IVS + obs_data_set_default_bool(s, "amazon_ivs_group", false); + obs_data_set_default_string(s, "amazon_ivs_channel_arn", ""); + obs_data_set_default_string(s, "aws_access_key", ""); + obs_data_set_default_string(s, "aws_secret_key", ""); + obs_data_set_default_string(s, "aws_region", "us-west-2"); + // Whisper parameters obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); obs_data_set_default_string(s, "initial_prompt", ""); diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 3311b18..a35d8d8 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -172,6 +172,9 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } + // time the operation + auto start = std::chrono::high_resolution_clock::now(); + // run the inference int whisper_full_result = -1; gf->whisper_params.duration_ms = (int)(duration_ms); @@ -191,6 +194,11 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter bfree(pcm32f_data); } + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + obs_log(gf->log_level, "Transcription time: %d ms for %d ms of audio", + (int)duration.count(), (int)duration_ms); + std::string language = gf->whisper_params.language; if (gf->whisper_params.language == nullptr || strlen(gf->whisper_params.language) == 0 || strcmp(gf->whisper_params.language, "auto") == 0) { @@ -202,86 +210,79 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter if (whisper_full_result != 0) { obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; - } else { - float sentence_p = 0.0f; - std::string text = ""; - std::string tokenIds = ""; - std::vector tokens; - for (int n_segment = 0; n_segment < whisper_full_n_segments(gf->whisper_context); - ++n_segment) { - const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment); - for (int j = 0; j < n_tokens; ++j) { - // get token - whisper_token_data token = whisper_full_get_token_data( - gf->whisper_context, n_segment, j); - const char *token_str = - whisper_token_to_str(gf->whisper_context, token.id); - bool keep = true; - // if the token starts with '[' and ends with ']', don't keep it - if (token_str[0] == '[' && - token_str[strlen(token_str) - 1] == ']') { - keep = false; - } - // if this is a special token, don't keep it - if (token.id >= 50256) { - keep = false; - } - // if the second to last token is .id == 13 ('.'), don't keep it - if (j == n_tokens - 2 && token.id == 13) { - keep = false; - } - // token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json - if (token.id > 50365 && token.id <= 51865) { - const float time = ((float)token.id - 50365.0f) * 0.02f; - const float duration_s = (float)duration_ms / 1000.0f; - const float ratio = std::max(time, duration_s) / - std::min(time, duration_s); + } + + float sentence_p = 0.0f; + std::string text = ""; + std::string tokenIds = ""; + std::vector tokens; + for (int n_segment = 0; n_segment < whisper_full_n_segments(gf->whisper_context); + ++n_segment) { + const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment); + for (int j = 0; j < n_tokens; ++j) { + // get token + whisper_token_data token = + whisper_full_get_token_data(gf->whisper_context, n_segment, j); + const char *token_str = whisper_token_to_str(gf->whisper_context, token.id); + bool keep = true; + // if the token starts with '[' and ends with ']', don't keep it + if (token_str[0] == '[' && token_str[strlen(token_str) - 1] == ']') { + keep = false; + } + // if this is a special token, don't keep it + if (token.id >= 50256) { + keep = false; + } + // if the second to last token is .id == 13 ('.'), don't keep it + if (j == n_tokens - 2 && token.id == 13) { + keep = false; + } + // token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json + if (token.id > 50365 && token.id <= 51865) { + const float time = ((float)token.id - 50365.0f) * 0.02f; + const float duration_s = (float)duration_ms / 1000.0f; + const float ratio = + std::max(time, duration_s) / std::min(time, duration_s); + obs_log(gf->log_level, + "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.", + token.id, time, duration_s, ratio); + if (ratio > 3.0f) { + // ratio is too high, skip this detection obs_log(gf->log_level, - "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.", - token.id, time, duration_s, ratio); - if (ratio > 3.0f) { - // ratio is too high, skip this detection - obs_log(gf->log_level, - "Time token ratio too high, skipping"); - return {DETECTION_RESULT_SILENCE, - "", - t0, - t1, - {}, - language}; - } - keep = false; + "Time token ratio too high, skipping"); + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; } + keep = false; + } - if (keep) { - sentence_p += token.p; - text += token_str; - tokens.push_back(token); - } - obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]", - n_segment, j, token.id, token_str, token.p, keep); + if (keep) { + sentence_p += token.p; + text += token_str; + tokens.push_back(token); } + obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]", + n_segment, j, token.id, token_str, token.p, keep); } - sentence_p /= (float)tokens.size(); - if (sentence_p < gf->sentence_psum_accept_thresh) { - obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping", - sentence_p, gf->sentence_psum_accept_thresh); - return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; - } - - obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str()); + } + sentence_p /= (float)tokens.size(); + if (sentence_p < gf->sentence_psum_accept_thresh) { + obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping", + sentence_p, gf->sentence_psum_accept_thresh); + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; + } - if (gf->log_words) { - obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(), - to_timestamp(t1).c_str(), sentence_p, text.c_str()); - } + obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str()); - if (text.empty() || text == "." || text == " " || text == "\n") { - return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; - } + if (gf->log_words) { + obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(), + to_timestamp(t1).c_str(), sentence_p, text.c_str()); + } - return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens, language}; + if (text.empty() || text == "." || text == " " || text == "\n") { + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; } + + return {DETECTION_RESULT_SPEECH, text, t0, t1, tokens, language}; } void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json new file mode 100644 index 0000000..7d31d40 --- /dev/null +++ b/vcpkg-configuration.json @@ -0,0 +1,14 @@ +{ + "default-registry": { + "kind": "git", + "baseline": "6f1ddd6b6878e7e66fcc35c65ba1d8feec2e01f8", + "repository": "https://github.com/microsoft/vcpkg" + }, + "registries": [ + { + "kind": "artifact", + "location": "https://github.com/microsoft/vcpkg-ce-catalog/archive/refs/heads/main.zip", + "name": "microsoft" + } + ] +} diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 0000000..3ed9a36 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,5 @@ +{ + "dependencies": [ + "openssl" + ] +}