Skip to content

Commit

Permalink
refactor: Add filter-replace-dialog.cpp for filter and replace functi… (
Browse files Browse the repository at this point in the history
#124)

* refactor: Add filter-replace-dialog.cpp for filter and replace functionality

* refactor: Improve filter-replace-dialog.cpp for filter and replace functionality
  • Loading branch information
royshil authored Jul 2, 2024
1 parent a2244c2 commit 32bbd99
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 30 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ target_sources(
src/whisper-utils/token-buffer-thread.cpp
src/translation/language_codes.cpp
src/translation/translation.cpp
src/translation/translation-utils.cpp)
src/translation/translation-utils.cpp
src/ui/filter-replace-dialog.cpp)

set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})

Expand Down
1 change: 1 addition & 0 deletions data/locale/en-US.ini
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ buffered_output_parameters="Buffered output parameters"
buffer_num_lines="Number of lines"
buffer_num_chars_per_line="Amount per line"
buffer_output_type="Output type"
open_filter_ui="Setup Filter and Replace"
19 changes: 5 additions & 14 deletions src/tests/localvocal-offline-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
gf->target_lang = "";
gf->translation_ctx.add_context = true;
gf->translation_output = "";
gf->suppress_sentences = "";
gf->translate = false;
gf->sentence_psum_accept_thresh = 0.4;

Expand Down Expand Up @@ -251,16 +250,14 @@ void set_text_callback(struct transcription_filter_data *gf,
str_copy = remove_leading_trailing_nonalpha(str_copy);

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
if (!gf->filter_words_replace.empty()) {
const std::string original_str_copy = str_copy;
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
// if suppress_sentence exists within str_copy, remove it (replace with "")
for (const auto &filter : gf->filter_words_replace) {
// if filter exists within str_copy, remove it (replace with "")
str_copy = std::regex_replace(str_copy,
std::regex(suppress_sentence), "");
std::regex(std::get<0>(filter)),
std::get<1>(filter));
}
if (original_str_copy != str_copy) {
obs_log(LOG_INFO, "Suppression: '%s' -> '%s'",
Expand Down Expand Up @@ -378,12 +375,6 @@ int wmain(int argc, wchar_t *argv[])
config["fix_utf8"] ? "true" : "false");
gf->fix_utf8 = config["fix_utf8"];
}
if (config.contains("suppress_sentences")) {
obs_log(LOG_INFO, "Setting suppress_sentences to %s",
config["suppress_sentences"].get<std::string>().c_str());
gf->suppress_sentences =
config["suppress_sentences"].get<std::string>();
}
if (config.contains("enable_audio_chunks_callback")) {
obs_log(LOG_INFO, "Setting enable_audio_chunks_callback to %s",
config["enable_audio_chunks_callback"] ? "true" : "false");
Expand Down
13 changes: 6 additions & 7 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,14 @@ void set_text_callback(struct transcription_filter_data *gf,
}

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
if (!gf->filter_words_replace.empty()) {
const std::string original_str_copy = str_copy;
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
// if suppress_sentence exists within str_copy, remove it (replace with "")
str_copy = std::regex_replace(str_copy, std::regex(suppress_sentence), "");
for (const auto &filter_words : gf->filter_words_replace) {
// if filter exists within str_copy, replace it with the replacement
str_copy = std::regex_replace(str_copy,
std::regex(std::get<0>(filter_words)),
std::get<1>(filter_words));
}
// if the text was modified, log the original and modified text
if (original_str_copy != str_copy) {
Expand Down
2 changes: 1 addition & 1 deletion src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct transcription_filter_data {
std::string target_lang;
std::string translation_output;
bool enable_token_ts_dtw = false;
std::string suppress_sentences;
std::vector<std::tuple<std::string, std::string>> filter_words_replace;
bool fix_utf8 = true;
bool enable_audio_chunks_callback = false;
bool source_signals_set = false;
Expand Down
31 changes: 25 additions & 6 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "translation/translation-utils.h"
#include "translation/translation.h"
#include "translation/translation-includes.h"
#include "ui/filter-replace-dialog.h"

void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source)
{
Expand Down Expand Up @@ -190,6 +191,8 @@ void transcription_filter_update(void *data, obs_data_t *s)
int new_buffer_num_chars_per_line = (int)obs_data_get_int(s, "buffer_num_chars_per_line");
TokenBufferSegmentation new_buffer_output_type =
(TokenBufferSegmentation)obs_data_get_int(s, "buffer_output_type");
gf->filter_words_replace =
deserialize_filter_words_replace(obs_data_get_string(s, "filter_words_replace"));

if (new_buffered_output) {
obs_log(gf->log_level, "buffered_output enable");
Expand Down Expand Up @@ -247,7 +250,6 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->translation_ctx.input_tokenization_style =
(InputTokenizationStyle)obs_data_get_int(s, "translate_input_tokenization_style");
gf->translation_output = obs_data_get_string(s, "translate_output");
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
std::string new_translate_model_index = obs_data_get_string(s, "translate_model");
std::string new_translation_model_path_external =
obs_data_get_string(s, "translation_model_path_external");
Expand Down Expand Up @@ -554,7 +556,6 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "translate_model", "whisper-based-translation");
obs_data_set_default_string(s, "translation_model_path_external", "");
obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100);
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4);

// translation options
Expand Down Expand Up @@ -841,7 +842,7 @@ obs_properties_t *transcription_filter_properties(void *data)
{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
"overlap_size_msec", "step_by_step_processing", "min_sub_duration",
"process_while_muted", "buffered_output", "vad_enabled", "log_level",
"suppress_sentences", "sentence_psum_accept_thresh", "vad_threshold",
"open_filter_ui", "sentence_psum_accept_thresh", "vad_threshold",
"buffered_output_group"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
Expand Down Expand Up @@ -902,9 +903,27 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_int(list, "INFO", LOG_INFO);
obs_property_list_add_int(list, "WARNING", LOG_WARNING);

// add a text input for sentences to suppress
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
OBS_TEXT_MULTILINE);
// add button to open filter and replace UI dialog
obs_properties_add_button2(
ppts, "open_filter_ui", MT_("open_filter_ui"),
[](obs_properties_t *props, obs_property_t *property, void *data_) {
UNUSED_PARAMETER(props);
UNUSED_PARAMETER(property);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
FilterReplaceDialog *filter_replace_dialog = new FilterReplaceDialog(
(QWidget *)obs_frontend_get_main_window(), gf_);
filter_replace_dialog->exec();
// store the filter data on the source settings
obs_data_t *settings = obs_source_get_settings(gf_->context);
// serialize the filter data
const std::string filter_data =
serialize_filter_words_replace(gf_->filter_words_replace);
obs_data_set_string(settings, "filter_words_replace", filter_data.c_str());
obs_data_release(settings);
return true;
},
gf);

obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
Expand Down
104 changes: 104 additions & 0 deletions src/ui/filter-replace-dialog.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#include "filter-replace-dialog.h"
#include "ui_filter-replace-dialog.h"

FilterReplaceDialog::FilterReplaceDialog(QWidget *parent, transcription_filter_data *ctx_)
: QDialog(parent),
ctx(ctx_),
ui(new Ui::FilterReplaceDialog)
{
ui->setupUi(this);

// populate the tableWidget with the filter_words_replace map
ui->tableWidget->setRowCount((int)ctx->filter_words_replace.size());
for (size_t i = 0; i < ctx->filter_words_replace.size(); i++) {
const std::string key = std::get<0>(ctx->filter_words_replace[i]);
const std::string value = std::get<1>(ctx->filter_words_replace[i]);
ui->tableWidget->setItem((int)i, 0,
new QTableWidgetItem(QString::fromStdString(key)));
ui->tableWidget->setItem((int)i, 1,
new QTableWidgetItem(QString::fromStdString(value)));
}

// connect toolButton_add
connect(ui->toolButton_add, &QToolButton::clicked, this, &FilterReplaceDialog::addFilter);
// connect toolButton_remove
connect(ui->toolButton_remove, &QToolButton::clicked, this,
&FilterReplaceDialog::removeFilter);
// connect edit triggers
connect(ui->tableWidget, &QTableWidget::itemChanged, this,
&FilterReplaceDialog::editFilter);
}

FilterReplaceDialog::~FilterReplaceDialog()
{
delete ui;
}

void FilterReplaceDialog::addFilter()
{
ui->tableWidget->insertRow(ui->tableWidget->rowCount());
// add an empty filter_words_replace map entry
ctx->filter_words_replace.push_back(std::make_tuple("", ""));
}

void FilterReplaceDialog::removeFilter()
{
if (ui->tableWidget->currentRow() == -1) {
return;
}
ui->tableWidget->removeRow(ui->tableWidget->currentRow());
// remove the filter_words_replace map entry
ctx->filter_words_replace.erase(ctx->filter_words_replace.begin() +
ui->tableWidget->currentRow() + 1);
}

void FilterReplaceDialog::editFilter(QTableWidgetItem *item)
{
if (item->row() >= (int)ctx->filter_words_replace.size()) {
return;
}

std::string key;
if (ui->tableWidget->item(item->row(), 0) == nullptr) {
key = "";
} else {
key = ui->tableWidget->item(item->row(), 0)->text().toStdString();
}
std::string value;
if (ui->tableWidget->item(item->row(), 1) == nullptr) {
value = "";
} else {
value = ui->tableWidget->item(item->row(), 1)->text().toStdString();
}
// use the row number to update the filter_words_replace map
ctx->filter_words_replace[item->row()] = std::make_tuple(key, value);
}

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace)
{
if (filter_words_replace.empty()) {
return "[]";
}
// use JSON to serialize the filter_words_replace map
nlohmann::json j;
for (const auto &entry : filter_words_replace) {
j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}});
}
return j.dump();
}

std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str)
{
if (filter_words_replace_str.empty()) {
return {};
}
// use JSON to deserialize the filter_words_replace map
std::vector<std::tuple<std::string, std::string>> filter_words_replace;
nlohmann::json j = nlohmann::json::parse(filter_words_replace_str);
for (const auto &entry : j) {
filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"]));
}
return filter_words_replace;
}
35 changes: 35 additions & 0 deletions src/ui/filter-replace-dialog.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef FILTERREPLACEDIALOG_H
#define FILTERREPLACEDIALOG_H

#include <QDialog>
#include <QTableWidgetItem>

#include "transcription-filter-data.h"

namespace Ui {
class FilterReplaceDialog;
}

class FilterReplaceDialog : public QDialog {
Q_OBJECT

public:
explicit FilterReplaceDialog(QWidget *parent, transcription_filter_data *ctx_);
~FilterReplaceDialog();

private:
Ui::FilterReplaceDialog *ui;
transcription_filter_data *ctx;

private slots:
void addFilter();
void removeFilter();
void editFilter(QTableWidgetItem *item);
};

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace);
std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str);

#endif // FILTERREPLACEDIALOG_H
Loading

0 comments on commit 32bbd99

Please sign in to comment.