diff --git a/CMakeLists.txt b/CMakeLists.txt
index c253488..b117275 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,12 +86,14 @@ target_sources(
 	PRIVATE src/plugin-main.c
 		src/transcription-filter.cpp
 		src/transcription-filter.c
+		src/transcription-utils.cpp
 		src/model-utils/model-downloader.cpp
 		src/model-utils/model-downloader-ui.cpp
 		src/model-utils/model-infos.cpp
 		src/whisper-utils/whisper-processing.cpp
 		src/whisper-utils/whisper-utils.cpp
 		src/whisper-utils/silero-vad-onnx.cpp
+		src/whisper-utils/token-buffer-thread.cpp
 		src/translation/translation.cpp
 		src/utils.cpp)
diff --git a/src/captions-thread.h b/src/captions-thread.h
deleted file mode 100644
index 1cdb079..0000000
--- a/src/captions-thread.h
+++ /dev/null
@@ -1,118 +0,0 @@
-#ifndef CAPTIONS_THREAD_H
-#define CAPTIONS_THREAD_H
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#include "plugin-support.h"
-
-class CaptionMonitor {
-public:
-	// default constructor
-	CaptionMonitor() = default;
-
-	~CaptionMonitor()
-	{
-		{
-			std::lock_guard lock(queueMutex);
-			stop = true;
-		}
-		condVar.notify_all();
-		workerThread.join();
-	}
-
-	void initialize(std::function callback_, size_t maxSize_,
-			std::chrono::seconds maxTime_)
-	{
-		this->callback = callback_;
-		this->maxSize = maxSize_;
-		this->maxTime = maxTime_;
-		this->initialized = true;
-		this->workerThread = std::thread(&CaptionMonitor::monitor, this);
-	}
-
-	void addWords(const std::vector &words)
-	{
-		{
-			std::lock_guard lock(queueMutex);
-			for (const auto &word : words) {
-				wordQueue.push_back(word);
-			}
-			this->newDataAvailable = true;
-		}
-		condVar.notify_all();
-	}
-
-private:
-	void monitor()
-	{
-		obs_log(LOG_INFO, "CaptionMonitor::monitor");
-		auto startTime = std::chrono::steady_clock::now();
-		while (true) {
-			std::unique_lock lock(this->queueMutex);
-			// wait for new data or stop signal
-			this->condVar.wait(lock,
-					   [this] { return this->newDataAvailable || this->stop; });
-
-			if (this->stop) {
-				break;
-			}
-
-			if (this->wordQueue.empty()) {
-				continue;
-			}
-
-			// emit up to maxSize words from the wordQueue
-			std::vector emitted;
-			while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
-				emitted.push_back(this->wordQueue.front());
-				this->wordQueue.pop_front();
-			}
-			// emit the caption, joining the words with a space
-			std::string output;
-			for (const auto &word : emitted) {
-				output += word + " ";
-			}
-			this->callback(output);
-			// push back the words that were emitted, in reverse order
-			for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
-				this->wordQueue.push_front(*it);
-			}
-
-			if (this->wordQueue.size() >= this->maxSize ||
-			    std::chrono::steady_clock::now() - startTime >= this->maxTime) {
-				// flush the queue if it's full or we've reached the max time
-				size_t words_to_flush =
-					std::min(this->wordQueue.size(), this->maxSize);
-				for (size_t i = 0; i < words_to_flush; ++i) {
-					wordQueue.pop_front();
-				}
-				startTime = std::chrono::steady_clock::now();
-			}
-
-			newDataAvailable = false;
-		}
-		obs_log(LOG_INFO, "CaptionMonitor::monitor: done");
-	}
-
-	std::deque wordQueue;
-	std::thread workerThread;
-	std::mutex queueMutex;
-	std::condition_variable condVar;
-	std::function callback;
-	size_t maxSize;
-	std::chrono::seconds maxTime;
-	bool stop;
-	bool initialized = false;
-	bool newDataAvailable = false;
-};
-
-#endif // CAPTIONS_THREAD_H
diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h
index 599be0b..09d88b4 100644
--- a/src/transcription-filter-data.h
+++ b/src/transcription-filter-data.h
@@ -17,25 +17,13 @@
 #include "translation/translation.h"
 #include "whisper-utils/silero-vad-onnx.h"
-#include "captions-thread.h"
+#include "whisper-utils/whisper-processing.h"
+#include "whisper-utils/token-buffer-thread.h"

 #define MAX_PREPROC_CHANNELS 10

 #define MT_ obs_module_text

-enum DetectionResult {
-	DETECTION_RESULT_UNKNOWN = 0,
-	DETECTION_RESULT_SILENCE = 1,
-	DETECTION_RESULT_SPEECH = 2,
-};
-
-struct DetectionResultWithText {
-	DetectionResult result;
-	std::string text;
-	uint64_t start_timestamp_ms;
-	uint64_t end_timestamp_ms;
-};
-
 struct transcription_filter_data {
 	obs_source_t *context; // obs filter source (this filter)
 	size_t channels;       // number of channels
@@ -116,7 +104,7 @@ struct transcription_filter_data {
 	// translation context
 	struct translation_context translation_ctx;

-	CaptionMonitor captions_monitor;
+	TokenBufferThread captions_monitor;

 	// ctor
 	transcription_filter_data()
diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp
index 572eb7f..bf4a48a 100644
--- a/src/transcription-filter.cpp
+++ b/src/transcription-filter.cpp
@@ -4,6 +4,7 @@
 #include "plugin-support.h"
 #include "transcription-filter.h"
 #include "transcription-filter-data.h"
+#include "transcription-utils.h"
 #include "model-utils/model-downloader.h"
 #include "whisper-utils/whisper-processing.h"
 #include "whisper-utils/whisper-language.h"
@@ -187,40 +188,6 @@ void acquire_weak_text_source_ref(struct transcription_filter_data *gf)
 	}
 }

-#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
-#define is_trail_byte(c) (((c)&0xc0) == 0x80)
-
-inline int lead_byte_length(const uint8_t c)
-{
-	if ((c & 0xe0) == 0xc0) {
-		return 2;
-	} else if ((c & 0xf0) == 0xe0) {
-		return 3;
-	} else if ((c & 0xf8) == 0xf0) {
-		return 4;
-	} else {
-		return 1;
-	}
-}
-
-inline bool is_valid_lead_byte(const uint8_t *c)
-{
-	const int length = lead_byte_length(c[0]);
-	if (length == 1) {
-		return true;
-	}
-	if (length == 2 && is_trail_byte(c[1])) {
-		return true;
-	}
-	if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
-		return true;
-	}
-	if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
-		return true;
-	}
-	return false;
-}
-
 void send_caption_to_source(const std::string &str_copy, struct transcription_filter_data *gf)
 {
 	if (!gf->text_source_mutex) {
@@ -267,44 +234,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 	}
 	gf->last_sub_render_time = now;

-#ifdef _WIN32
-	// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
-	// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
-	std::stringstream ss;
-	uint8_t *c_str = (uint8_t *)result.text.c_str();
-	for (size_t i = 0; i < result.text.size(); ++i) {
-		if (is_lead_byte(c_str[i])) {
-			// this is a unicode leading byte
-			// if the next char is 0xff - it's a bug char, replace it with 0x9f
-			if (c_str[i + 1] == 0xff) {
-				c_str[i + 1] = 0x9f;
-			}
-			if (!is_valid_lead_byte(c_str + i)) {
-				// This is a bug lead byte, because it's length 3 and the i+2 byte is also
-				// a lead byte
-				c_str[i] = c_str[i] - 0x20;
-			}
-		} else {
-			if (c_str[i] >= 0xf8) {
-				// this may be a malformed lead byte.
-				// lets see if it becomes a valid lead byte if we "fix" it
-				uint8_t buf_[4];
-				buf_[0] = c_str[i] - 0x20;
-				buf_[1] = c_str[i + 1];
-				buf_[2] = c_str[i + 2];
-				buf_[3] = c_str[i + 3];
-				if (is_valid_lead_byte(buf_)) {
-					// this is a malformed lead byte, fix it
-					c_str[i] = c_str[i] - 0x20;
-				}
-			}
-		}
-	}
-
-	std::string str_copy = (char *)c_str;
-#else
-	std::string str_copy = result.text;
-#endif
+	std::string str_copy = fix_utf8(result.text);

 	// remove trailing spaces, newlines, tabs or punctuation
 	str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
@@ -333,7 +263,7 @@ void set_text_callback(struct transcription_filter_data *gf,
 	gf->last_text = str_copy;

 	if (gf->buffered_output) {
-		gf->captions_monitor.addWords(split_words(str_copy));
+		gf->captions_monitor.addWords(result.tokens);
 	}

 	if (gf->caption_to_stream) {
@@ -673,13 +603,15 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
 	gf->whisper_context = nullptr;

 	gf->captions_monitor.initialize(
+		gf,
 		[gf](const std::string &text) {
 			obs_log(LOG_INFO, "Captions: %s", text.c_str());
 			if (gf->buffered_output) {
 				send_caption_to_source(text, gf);
 			}
 		},
-		20, std::chrono::seconds(10));
+		20,
+		std::chrono::seconds(10));

 	obs_log(gf->log_level, "run update");
 	// get the settings updated on the filter data struct
@@ -960,8 +892,8 @@ obs_properties_t *transcription_filter_properties(void *data)

 	obs_properties_add_int_slider(ppts, "buffer_size_msec", MT_("buffer_size_msec"), 1000,
 				      DEFAULT_BUFFER_SIZE_MSEC, 250);
-	obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"), 50, 300,
-				      50);
+	obs_properties_add_int_slider(ppts, "overlap_size_msec", MT_("overlap_size_msec"), 250, DEFAULT_OVERLAP_SIZE_MSEC,
+				      250);

 	obs_property_t *step_by_step_processing = obs_properties_add_bool(
 		ppts, "step_by_step_processing", MT_("step_by_step_processing"));
diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp
new file mode 100644
index 0000000..1f9c72f
--- /dev/null
+++ b/src/transcription-utils.cpp
@@ -0,0 +1,80 @@
+#include "transcription-utils.h"
+
+#include
+
+
+#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
+#define is_trail_byte(c) (((c)&0xc0) == 0x80)
+
+inline int lead_byte_length(const uint8_t c)
+{
+	if ((c & 0xe0) == 0xc0) {
+		return 2;
+	} else if ((c & 0xf0) == 0xe0) {
+		return 3;
+	} else if ((c & 0xf8) == 0xf0) {
+		return 4;
+	} else {
+		return 1;
+	}
+}
+
+inline bool is_valid_lead_byte(const uint8_t *c)
+{
+	const int length = lead_byte_length(c[0]);
+	if (length == 1) {
+		return true;
+	}
+	if (length == 2 && is_trail_byte(c[1])) {
+		return true;
+	}
+	if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
+		return true;
+	}
+	if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
+		return true;
+	}
+	return false;
+}
+
+std::string fix_utf8(const std::string &str)
+{
+#ifdef _WIN32
+	// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
+	// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
+	std::stringstream ss;
+	uint8_t *c_str = (uint8_t *)str.c_str();
+	for (size_t i = 0; i < str.size(); ++i) {
+		if (is_lead_byte(c_str[i])) {
+			// this is a unicode leading byte
+			// if the next char is 0xff - it's a bug char, replace it with 0x9f
+			if (c_str[i + 1] == 0xff) {
+				c_str[i + 1] = 0x9f;
+			}
+			if (!is_valid_lead_byte(c_str + i)) {
+				// This is a bug lead byte, because it's length 3 and the i+2 byte is also
+				// a lead byte
+				c_str[i] = c_str[i] - 0x20;
+			}
+		} else {
+			if (c_str[i] >= 0xf8) {
+				// this may be a malformed lead byte.
+				// lets see if it becomes a valid lead byte if we "fix" it
+				uint8_t buf_[4];
+				buf_[0] = c_str[i] - 0x20;
+				buf_[1] = c_str[i + 1];
+				buf_[2] = c_str[i + 2];
+				buf_[3] = c_str[i + 3];
+				if (is_valid_lead_byte(buf_)) {
+					// this is a malformed lead byte, fix it
+					c_str[i] = c_str[i] - 0x20;
+				}
+			}
+		}
+	}
+
+	return std::string((char*)c_str);
+#else
+	return str;
+#endif
+}
\ No newline at end of file
diff --git a/src/transcription-utils.h b/src/transcription-utils.h
new file mode 100644
index 0000000..cd00827
--- /dev/null
+++ b/src/transcription-utils.h
@@ -0,0 +1,8 @@
+#ifndef TRANSCRIPTION_UTILS_H
+#define TRANSCRIPTION_UTILS_H
+
+#include
+
+std::string fix_utf8(const std::string &str);
+
+#endif // TRANSCRIPTION_UTILS_H
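For reference, a minimal usage sketch of the new fix_utf8() helper (illustrative only, not part of the patch). The mangled byte sequence is an assumed example of the Windows bug described inside fix_utf8(): the 0xC3 lead byte of "é" arrives shifted up by 0x20 to 0xE3, and the helper shifts it back.

// Illustrative sketch, not part of the patch. Assumes a Windows build
// (on other platforms fix_utf8() returns its input unchanged) and that
// the include path resolves to the header added above.
#include <cstdio>
#include <string>
#include "transcription-utils.h"

int main()
{
	// "Pokémon" whose é lead byte 0xC3 was mangled into 0xE3 by the bug
	std::string mangled = "Pok\xE3\xA9mon";
	std::string fixed = fix_utf8(mangled);
	// expected: "Pok\xC3\xA9mon", i.e. valid UTF-8 for "Pokémon"
	printf("%s\n", fixed.c_str());
	return 0;
}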
diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp
new file mode 100644
index 0000000..ff67953
--- /dev/null
+++ b/src/whisper-utils/token-buffer-thread.cpp
@@ -0,0 +1,119 @@
+#include "token-buffer-thread.h"
+#include "./whisper-utils.h"
+
+TokenBufferThread::~TokenBufferThread()
+{
+	{
+		std::lock_guard lock(queueMutex);
+		stop = true;
+	}
+	condVar.notify_all();
+	workerThread.join();
+}
+
+void TokenBufferThread::initialize(struct transcription_filter_data* gf_, std::function callback_, size_t maxSize_,
+				   std::chrono::seconds maxTime_)
+{
+	this->gf = gf_;
+	this->callback = callback_;
+	this->maxSize = maxSize_;
+	this->maxTime = maxTime_;
+	this->initialized = true;
+	this->workerThread = std::thread(&TokenBufferThread::monitor, this);
+}
+
+void TokenBufferThread::log_token_vector(const std::vector &tokens)
+{
+	std::string output;
+	for (const auto &token : tokens) {
+		const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
+		output += token_str;
+	}
+	obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str());
+}
+
+void TokenBufferThread::addWords(const std::vector &words)
+{
+	obs_log(LOG_INFO, "TokenBufferThread::addWords");
+	{
+		std::lock_guard lock(queueMutex);
+
+		// convert current wordQueue to vector
+		std::vector currentWords(wordQueue.begin(), wordQueue.end());
+
+		log_token_vector(currentWords);
+		log_token_vector(words);
+
+		// run reconstructSentence
+		std::vector reconstructed = reconstructSentence(currentWords, words);
+
+		log_token_vector(reconstructed);
+
+		// clear the wordQueue
+		wordQueue.clear();
+
+		// add the reconstructed sentence to the wordQueue
+		for (const auto &word : reconstructed) {
+			wordQueue.push_back(word);
+		}
+	}
+	condVar.notify_all();
+}
+
+void TokenBufferThread::monitor()
+{
+	obs_log(LOG_INFO, "TokenBufferThread::monitor");
+	auto startTime = std::chrono::steady_clock::now();
+	while (this->initialized && !this->stop) {
+		std::unique_lock lock(this->queueMutex);
+		// wait for new data or stop signal
+		this->condVar.wait(lock,
+				   [this] { return this->newDataAvailable || this->stop; });
+
+		if (this->stop) {
+			break;
+		}
+
+		if (this->wordQueue.empty()) {
+			continue;
+		}
+
+		if (this->gf->whisper_context == nullptr) {
+			continue;
+		}
+
+		// emit up to maxSize words from the wordQueue
+		std::vector emitted;
+		while (!this->wordQueue.empty() && emitted.size() <= this->maxSize) {
+			emitted.push_back(this->wordQueue.front());
+			this->wordQueue.pop_front();
+		}
+		obs_log(LOG_INFO, "TokenBufferThread::monitor: emitting %d words", emitted.size());
+		log_token_vector(emitted);
+		// emit the caption from the tokens
+		std::string output;
+		for (const auto &token : emitted) {
+			const char *token_str = whisper_token_to_str(this->gf->whisper_context, token.id);
+			output += token_str;
+		}
+		this->callback(output);
+		// push back the words that were emitted, in reverse order
+		for (auto it = emitted.rbegin(); it != emitted.rend(); ++it) {
+			this->wordQueue.push_front(*it);
+		}
+
+		if (this->wordQueue.size() >= this->maxSize ||
+		    std::chrono::steady_clock::now() - startTime >= this->maxTime) {
+			// flush the queue if it's full or we've reached the max time
+			size_t words_to_flush =
+				std::min(this->wordQueue.size(), this->maxSize);
+			for (size_t i = 0; i < words_to_flush; ++i) {
+				wordQueue.pop_front();
+			}
+			startTime = std::chrono::steady_clock::now();
+		}
+
+		newDataAvailable = false;
+	}
+	obs_log(LOG_INFO, "TokenBufferThread::monitor: done");
+}
\ No newline at end of file
diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h
new file mode 100644
index 0000000..2f14267
--- /dev/null
+++ b/src/whisper-utils/token-buffer-thread.h
@@ -0,0 +1,48 @@
+#ifndef TOKEN_BUFFER_THREAD_H
+#define TOKEN_BUFFER_THREAD_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+#include "plugin-support.h"
+
+struct transcription_filter_data;
+
+class TokenBufferThread {
+public:
+	// default constructor
+	TokenBufferThread() = default;
+
+	~TokenBufferThread();
+	void initialize(struct transcription_filter_data* gf, std::function callback_, size_t maxSize_,
+			std::chrono::seconds maxTime_);
+
+	void addWords(const std::vector &words);
+
+private:
+	void monitor();
+	void log_token_vector(const std::vector &tokens);
+	struct transcription_filter_data *gf;
+	std::deque wordQueue;
+	std::thread workerThread;
+	std::mutex queueMutex;
+	std::condition_variable condVar;
+	std::function callback;
+	size_t maxSize;
+	std::chrono::seconds maxTime;
+	bool stop;
+	bool initialized = false;
+	bool newDataAvailable = false;
+};
+
+#endif
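To make the emit/flush behaviour of TokenBufferThread::monitor() easier to follow, here is a simplified, single-threaded analogue using plain strings instead of whisper tokens (illustrative only; the queue contents and maxSize are invented):

// Illustrative sketch, not part of the patch: the sliding caption window.
#include <cstdio>
#include <deque>
#include <string>
#include <vector>

int main()
{
	std::deque<std::string> queue = {"this", "is", "a", "long", "running", "caption"};
	const size_t maxSize = 4;

	// emit up to maxSize words without consuming them (monitor() above pops
	// them and pushes them back, which has the same effect)
	std::vector<std::string> emitted;
	for (size_t i = 0; i < queue.size() && emitted.size() < maxSize; ++i)
		emitted.push_back(queue[i]);
	std::string caption;
	for (const auto &w : emitted)
		caption += w + " ";
	printf("caption: %s\n", caption.c_str()); // "this is a long "

	// when the queue grows past maxSize (or maxTime elapses), the oldest
	// maxSize words are flushed so the caption keeps sliding forward
	if (queue.size() >= maxSize)
		for (size_t i = 0; i < maxSize; ++i)
			queue.pop_front();
	printf("left in queue: %zu words\n", queue.size()); // 2
	return 0;
}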
diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp
index 744fd03..5d7cc61 100644
--- a/src/whisper-utils/whisper-processing.cpp
+++ b/src/whisper-utils/whisper-processing.cpp
@@ -211,16 +211,14 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 							 const float *pcm32f_data, size_t pcm32f_size, bool zero_start)
 {
-	static std::vector last_tokens;
-
 	if (gf == nullptr) {
 		obs_log(LOG_ERROR, "run_whisper_inference: gf is null");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}

 	if (pcm32f_data == nullptr || pcm32f_size == 0) {
 		obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}

 	obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__,
@@ -230,7 +228,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 	std::lock_guard lock(*gf->whisper_ctx_mutex);
 	if (gf->whisper_context == nullptr) {
 		obs_log(LOG_WARNING, "whisper context is null");
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}

 	// Get the duration in ms since the beginning of the stream (gf->start_timestamp_ms)
@@ -249,12 +247,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 		obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what());
 		whisper_free(gf->whisper_context);
 		gf->whisper_context = nullptr;
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	}

 	if (whisper_full_result != 0) {
 		obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result);
-		return {DETECTION_RESULT_UNKNOWN, "", 0, 0};
+		return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
 	} else {
 		// duration in ms
 		const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE);
@@ -278,18 +276,19 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 				whisper_full_get_token_data(gf->whisper_context, n_segment, j);
 			const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
 			bool keep = !end;
-			if (zero_start && token.t_dtw < 20) {
-				keep = false;
-			}
+			// if (zero_start && token.t_dtw < 20) {
+			// 	keep = false;
+			// }
 			if (token.t_dtw == -1) {
 				keep = false;
 			}
-			if ((token.t_dtw < 20 || token.t_dtw > segment_cutoff) && token.p < 0.8) {
-				keep = false;
-				if (token.t_dtw > segment_cutoff) {
-					end = true;
-				}
-			}
+			if ((j == n_tokens - 2 || j == n_tokens - 3) && token.p < 0.5) {
+				keep = false;
+			}
+			// if the second to last token is .id == 13 ('.'), don't keep it
+			if (j == n_tokens - 2 && token.id == 13) {
+				keep = false;
+			}

 			if (keep) {
 				text += token_str;
@@ -304,20 +303,6 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 		obs_log(LOG_INFO, "Decoded sentence: '%s'", text.c_str());
 		obs_log(LOG_INFO, "Token IDs: %s", tokenIds.c_str());

-		// reconstruct sentence
-		if (last_tokens.size() > 0) {
-			std::vector sentence = reconstructSentence(last_tokens, tokens);
-			std::string sentence_str = "";
-			for (const whisper_token_data &token : sentence) {
-				const char *token_str = whisper_token_to_str(gf->whisper_context, token.id);
-				sentence_str += token_str;
-			}
-			obs_log(LOG_INFO, "Reconstructed sentence: '%s'", sentence_str.c_str());
-			last_tokens = sentence;
-		} else {
-			last_tokens = tokens;
-		}
-
 		// convert text to lowercase
 		std::string text_lower(text);
 		std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
@@ -333,10 +318,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
 		}

 		if (text_lower.empty() || text_lower == ".") {
-			return {DETECTION_RESULT_SILENCE, "", 0, 0};
+			return {DETECTION_RESULT_SILENCE, "", 0, 0, {}};
 		}

-		return {DETECTION_RESULT_SPEECH, text_lower, offset_ms, offset_ms + duration_ms};
+		return {DETECTION_RESULT_SPEECH, text_lower, offset_ms, offset_ms + duration_ms, tokens};
 	}
 }

@@ -453,13 +438,13 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 			set_text_callback(gf, inference_result);
 		} else if (inference_result.result == DETECTION_RESULT_SILENCE) {
 			// output inference result to a text source
-			set_text_callback(gf, {inference_result.result, "[silence]", 0, 0});
+			set_text_callback(gf, {inference_result.result, "[silence]", 0, 0, {}});
 		}
 	} else {
 		if (gf->log_words) {
 			obs_log(LOG_INFO, "skipping inference");
 		}
-		set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0});
+		set_text_callback(gf, {DETECTION_RESULT_UNKNOWN, "[skip]", 0, 0, {}});
 	}

 	// end of timer
@@ -470,6 +455,10 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 		  (int)duration);

 	if (last_step_in_segment) {
+		const uint64_t overlap_size_ms = (uint64_t)(gf->overlap_frames * 1000 / gf->sample_rate);
+		obs_log(gf->log_level,
+			"copying %lu frames (%lu ms) from the end of the buffer (pos %lu) to the beginning",
+			gf->overlap_frames, overlap_size_ms, gf->last_num_frames - gf->overlap_frames);
 		for (size_t c = 0; c < gf->channels; c++) {
 			// This is the last step in the segment - reset the copy buffer (include overlap frames)
 			// move overlap frames from the end of the last copy_buffers to the beginning
@@ -479,8 +468,8 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
 			// zero out the rest of the buffer, just in case
 			memset(gf->copy_buffers[c] + gf->overlap_frames, 0,
 			       (gf->frames - gf->overlap_frames) * sizeof(float));
-			gf->last_num_frames = gf->overlap_frames;
 		}
+		gf->last_num_frames = gf->overlap_frames;
 	}
 }
diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h
index d3215f5..409aaab 100644
--- a/src/whisper-utils/whisper-processing.h
+++ b/src/whisper-utils/whisper-processing.h
@@ -1,11 +1,27 @@
 #ifndef WHISPER_PROCESSING_H
 #define WHISPER_PROCESSING_H

+#include
+
 // buffer size in msec
 #define DEFAULT_BUFFER_SIZE_MSEC 3000
 // overlap in msec
 #define DEFAULT_OVERLAP_SIZE_MSEC 1000

+enum DetectionResult {
+	DETECTION_RESULT_UNKNOWN = 0,
+	DETECTION_RESULT_SILENCE = 1,
+	DETECTION_RESULT_SPEECH = 2,
+};
+
+struct DetectionResultWithText {
+	DetectionResult result;
+	std::string text;
+	uint64_t start_timestamp_ms;
+	uint64_t end_timestamp_ms;
+	std::vector tokens;
+};
+
 void whisper_loop(void *data);
 struct whisper_context *init_whisper_context(const std::string &model_path,
 					     struct transcription_filter_data *gf);
diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp
index 23e4c5c..183afbc 100644
--- a/src/whisper-utils/whisper-utils.cpp
+++ b/src/whisper-utils/whisper-utils.cpp
@@ -140,7 +140,9 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
 #else
 	std::string silero_vad_model_path = silero_vad_model_file;
 #endif
-	gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));
+	// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
+	// for silero vad parameters
+	gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000, 200, 250));

 	gf->whisper_context = init_whisper_context(path, gf);
 	if (gf->whisper_context == nullptr) {
@@ -152,28 +154,39 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
 	gf->whisper_thread.swap(new_whisper_thread);
 }

-// Dummy function that finds start of overlap; to be replaced with actual function
+// Finds start of 2-token overlap between two sequences of tokens
+// Returns a pair of indices of the first overlapping tokens in the two sequences
+// If no overlap is found, the function returns {-1, -1}
+// Allows for a single token mismatch in the overlap
 std::pair findStartOfOverlap(const std::vector &seq1,
-				   const std::vector &seq2)
+					     const std::vector &seq2)
 {
-	for (int i = 0; i < seq1.size(); ++i) {
-		for (int j = 0; j < seq2.size(); ++j) {
-			if (seq1[i].id == seq2[j].id) {
-				int k = 0;
-				while (i + k < seq1.size() && j + k < seq2.size() &&
-				       seq1[i + k].id == seq2[j + k].id) {
-					k++;
-				}
-				if (k > 1) {
-					return {i, j};
-				}
-			}
-		}
-	}
-	return {-1, -1};
+	if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) {
+		return {-1, -1};
+	}
+	for (int i = 0; i < seq1.size() - 1; ++i) {
+		for (int j = 0; j < seq2.size() - 1; ++j) {
+			if (seq1[i].id == seq2[j].id) {
+				// Check if the next token in both sequences is the same
+				if (seq1[i + 1].id == seq2[j + 1].id) {
+					return {i, j};
+				}
+				// 1-skip check on seq1
+				if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) {
+					return {i, j};
+				}
+				// 1-skip check on seq2
+				if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) {
+					return {i, j};
+				}
+			}
+		}
+	}
+	return {-1, -1};
 }

 // Function to reconstruct a whole sentence from two sentences using overlap info
+// If no overlap is found, the function returns the concatenation of the two sequences
 std::vector reconstructSentence(const std::vector &seq1,
 				const std::vector &seq2)
 {
@@ -181,7 +194,18 @@ std::vector reconstructSentence(const std::vector
 	std::vector reconstructed;

 	if (overlap.first == -1 || overlap.second == -1) {
-		return reconstructed; // Return empty if no overlap found
+		// Return concat of seq1 and seq2 if no overlap found
+
+		// check if the last token of seq1 == the first token of seq2
+		if (!seq1.empty() && !seq2.empty() && seq1.back().id == seq2.front().id) {
+			// don't add the last token of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1);
+		} else {
+			// add all tokens of seq1
+			reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end());
+		}
+		reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end());
+		return reconstructed;
 	}

 	// Add tokens from the first sequence up to the overlap
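A hedged worked example of the overlap merge (token ids and the include paths are invented for illustration, and it assumes whisper-utils.h declares reconstructSentence as used by token-buffer-thread.cpp above):

// Illustrative sketch, not part of the patch.
#include <cstdio>
#include <vector>
#include <whisper.h>
#include "whisper-utils/whisper-utils.h"

static whisper_token_data tok(whisper_token id)
{
	whisper_token_data t{};
	t.id = id;
	return t;
}

int main()
{
	// seq1 and seq2 share the two-token overlap {30, 40}, e.g. the tail of
	// one transcription step and the head of the next
	std::vector<whisper_token_data> seq1 = {tok(10), tok(20), tok(30), tok(40)};
	std::vector<whisper_token_data> seq2 = {tok(30), tok(40), tok(50)};

	// findStartOfOverlap should return {2, 0}; reconstructSentence is then
	// expected to yield seq1 up to the overlap followed by seq2: 10 20 30 40 50
	std::vector<whisper_token_data> merged = reconstructSentence(seq1, seq2);
	for (const auto &t : merged)
		printf("%d ", t.id);
	printf("\n");
	return 0;
}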