Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start and stop based on filter enable status #111

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include "transcription-utils.h"
#include "translation/translation.h"
#include "translation/translation-includes.h"

#define SEND_TIMED_METADATA_URL "http://localhost:8080/timed-metadata"
#include "whisper-utils/whisper-utils.h"
#include "whisper-utils/whisper-model-utils.h"

void send_caption_to_source(const std::string &target_source_name, const std::string &caption,
struct transcription_filter_data *gf)
Expand Down Expand Up @@ -130,7 +130,13 @@ void set_text_callback(struct transcription_filter_data *gf,
if (gf->caption_to_stream) {
obs_output_t *streaming_output = obs_frontend_get_streaming_output();
if (streaming_output) {
obs_output_output_caption_text1(streaming_output, str_copy.c_str());
// calculate the duration in seconds
const uint64_t duration =
result.end_timestamp_ms - result.start_timestamp_ms;
obs_log(gf->log_level, "Sending caption to streaming output: %s",
str_copy.c_str());
obs_output_output_caption_text2(streaming_output, str_copy.c_str(),
(double)duration / 1000.0);
obs_output_release(streaming_output);
}
}
Expand Down Expand Up @@ -285,3 +291,23 @@ void media_stopped_callback(void *data_, calldata_t *cd)
gf_->active = false;
reset_caption_state(gf_);
}

void enable_callback(void *data_, calldata_t *cd)
{
transcription_filter_data *gf_ = static_cast<struct transcription_filter_data *>(data_);
bool enable = calldata_bool(cd, "enabled");
if (enable) {
obs_log(gf_->log_level, "enable_callback: enable");
gf_->active = true;
reset_caption_state(gf_);
// get filter settings from gf_->context
obs_data_t *settings = obs_source_get_settings(gf_->context);
update_whisper_model(gf_, settings);
obs_data_release(settings);
} else {
obs_log(gf_->log_level, "enable_callback: disable");
gf_->active = false;
reset_caption_state(gf_);
shutdown_whisper_thread(gf_);
}
}
1 change: 1 addition & 0 deletions src/transcription-filter-callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ void media_started_callback(void *data_, calldata_t *cd);
void media_pause_callback(void *data_, calldata_t *cd);
void media_restart_callback(void *data_, calldata_t *cd);
void media_stopped_callback(void *data_, calldata_t *cd);
void enable_callback(void *data_, calldata_t *cd);

#endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */
1 change: 1 addition & 0 deletions src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ struct transcription_filter_data {
bool fix_utf8 = true;
bool enable_audio_chunks_callback = false;
bool source_signals_set = false;
bool initial_creation = true;

// Last transcription result
std::string last_text;
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ struct obs_source_info transcription_filter_info = {
.deactivate = transcription_filter_deactivate,
.filter_audio = transcription_filter_filter_audio,
.filter_remove = transcription_filter_remove,
.show = transcription_filter_show,
.hide = transcription_filter_hide,
};
119 changes: 74 additions & 45 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
if (!audio) {
return nullptr;
}

if (data == nullptr) {
return audio;
}
Expand Down Expand Up @@ -137,6 +138,9 @@ void transcription_filter_destroy(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_disconnect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "filter destroy");
shutdown_whisper_thread(gf);

Expand Down Expand Up @@ -167,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream");
Expand Down Expand Up @@ -293,51 +297,61 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->text_source_name = new_text_source_name;
}

obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);

obs_log(gf->log_level, "update whisper params");
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
{
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);

gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");
gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");

gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str();
gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language =
obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);
}
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);

if (gf->initial_creation && obs_source_enabled(gf->context)) {
// source was enabled on creation
obs_data_t *settings = obs_source_get_settings(gf->context);
update_whisper_model(gf, settings);
obs_data_release(settings);
gf->active = true;
gf->initial_creation = false;
}
}

Expand Down Expand Up @@ -421,12 +435,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_connect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "run update");
// get the settings updated on the filter data struct
transcription_filter_update(gf, settings);

gf->active = true;

// handle the event OBS_FRONTEND_EVENT_RECORDING_STARTING to reset the srt sentence number
// to match the subtitles with the recording
obs_frontend_add_event_callback(recording_state_callback, gf);
Expand Down Expand Up @@ -466,6 +481,20 @@ void transcription_filter_deactivate(void *data)
gf->active = false;
}

/**
 * OBS source "show" callback for the transcription filter.
 * Currently only emits a debug log line; no state is changed.
 *
 * @param data Opaque pointer to the owning transcription_filter_data.
 */
void transcription_filter_show(void *data)
{
	auto *gf = static_cast<struct transcription_filter_data *>(data);
	obs_log(gf->log_level, "filter show");
}

/**
 * OBS source "hide" callback for the transcription filter.
 * Currently only emits a debug log line; no state is changed.
 *
 * @param data Opaque pointer to the owning transcription_filter_data.
 */
void transcription_filter_hide(void *data)
{
	auto *gf = static_cast<struct transcription_filter_data *>(data);
	obs_log(gf->log_level, "filter hide");
}

void transcription_filter_defaults(obs_data_t *s)
{
obs_log(LOG_INFO, "filter defaults");
Expand Down Expand Up @@ -586,11 +615,11 @@ obs_properties_t *transcription_filter_properties(void *data)
whisper_model_path_external,
[](void *data_, obs_properties_t *props, obs_property_t *property,
obs_data_t *settings) {
obs_log(LOG_INFO, "whisper_model_path_external modified");
UNUSED_PARAMETER(property);
UNUSED_PARAMETER(props);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
obs_log(gf_->log_level, "whisper_model_path_external modified");
transcription_filter_update(gf_, settings);
return true;
},
Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ void transcription_filter_deactivate(void *data);
void transcription_filter_defaults(obs_data_t *s);
obs_properties_t *transcription_filter_properties(void *data);
void transcription_filter_remove(void *data, obs_source_t *source);
void transcription_filter_show(void *data);
void transcription_filter_hide(void *data);

const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai/obs-localvocal/\">LocalVocal</a> (%1) by "
Expand Down
18 changes: 12 additions & 6 deletions src/whisper-utils/whisper-model-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
return;
}
obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file);
std::string silero_vad_model_file_str = std::string(silero_vad_model_file);
bfree(silero_vad_model_file);

if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
Expand Down Expand Up @@ -49,14 +52,15 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
model_info,
[gf, new_model_path, silero_vad_model_file](
[gf, new_model_path, silero_vad_model_file_str](
int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(
gf, path, silero_vad_model_file);
gf, path,
silero_vad_model_file_str.c_str());
} else {
obs_log(LOG_ERROR, "Model download failed");
}
Expand All @@ -65,7 +69,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
// Model exists, just load it
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, model_file_found,
silero_vad_model_file);
silero_vad_model_file_str.c_str());
}
} else {
// new model is external file, get file location from file property
Expand All @@ -82,8 +86,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, external_model_file_path,
silero_vad_model_file);
start_whisper_thread_with_path(
gf, external_model_file_path,
silero_vad_model_file_str.c_str());
}
}
}
Expand All @@ -101,6 +106,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file);
start_whisper_thread_with_path(gf, gf->whisper_model_path,
silero_vad_model_file_str.c_str());
}
}
19 changes: 15 additions & 4 deletions src/whisper-utils/whisper-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

#include <obs-module.h>

#ifdef _WIN32
#include <Windows.h>
#endif

void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
obs_log(gf->log_level, "shutdown_whisper_thread");
Expand All @@ -27,7 +31,8 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,
const std::string &whisper_model_path,
const char *silero_vad_model_file)
{
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str());
obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s",
whisper_model_path.c_str(), silero_vad_model_file);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context != nullptr) {
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
Expand All @@ -36,16 +41,22 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,

// initialize Silero VAD
#ifdef _WIN32
std::wstring silero_vad_model_path;
silero_vad_model_path.assign(silero_vad_model_file,
silero_vad_model_file + strlen(silero_vad_model_file));
// convert mbstring to wstring
int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file,
strlen(silero_vad_model_file), NULL, 0);
std::wstring silero_vad_model_path(count, 0);
MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file),
&silero_vad_model_path[0], count);
obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str());
#else
std::string silero_vad_model_path = silero_vad_model_file;
obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str());
#endif
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
// for silero vad parameters
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));

obs_log(gf->log_level, "Create whisper context");
gf->whisper_context = init_whisper_context(whisper_model_path, gf);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to initialize whisper context");
Expand Down
Loading