Skip to content

refactor: Add filter-replace-dialog.cpp for filter and replace functi… #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ target_sources(
src/whisper-utils/token-buffer-thread.cpp
src/translation/language_codes.cpp
src/translation/translation.cpp
src/translation/translation-utils.cpp)
src/translation/translation-utils.cpp
src/ui/filter-replace-dialog.cpp)

set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})

Expand Down
1 change: 1 addition & 0 deletions data/locale/en-US.ini
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ buffered_output_parameters="Buffered output parameters"
buffer_num_lines="Number of lines"
buffer_num_chars_per_line="Amount per line"
buffer_output_type="Output type"
open_filter_ui="Setup Filter and Replace"
19 changes: 5 additions & 14 deletions src/tests/localvocal-offline-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
gf->target_lang = "";
gf->translation_ctx.add_context = true;
gf->translation_output = "";
gf->suppress_sentences = "";
gf->translate = false;
gf->sentence_psum_accept_thresh = 0.4;

Expand Down Expand Up @@ -251,16 +250,14 @@ void set_text_callback(struct transcription_filter_data *gf,
str_copy = remove_leading_trailing_nonalpha(str_copy);

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
if (!gf->filter_words_replace.empty()) {
const std::string original_str_copy = str_copy;
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
// if suppress_sentence exists within str_copy, remove it (replace with "")
for (const auto &filter : gf->filter_words_replace) {
// if filter exists within str_copy, remove it (replace with "")
str_copy = std::regex_replace(str_copy,
std::regex(suppress_sentence), "");
std::regex(std::get<0>(filter)),
std::get<1>(filter));
}
if (original_str_copy != str_copy) {
obs_log(LOG_INFO, "Suppression: '%s' -> '%s'",
Expand Down Expand Up @@ -378,12 +375,6 @@ int wmain(int argc, wchar_t *argv[])
config["fix_utf8"] ? "true" : "false");
gf->fix_utf8 = config["fix_utf8"];
}
if (config.contains("suppress_sentences")) {
obs_log(LOG_INFO, "Setting suppress_sentences to %s",
config["suppress_sentences"].get<std::string>().c_str());
gf->suppress_sentences =
config["suppress_sentences"].get<std::string>();
}
if (config.contains("enable_audio_chunks_callback")) {
obs_log(LOG_INFO, "Setting enable_audio_chunks_callback to %s",
config["enable_audio_chunks_callback"] ? "true" : "false");
Expand Down
13 changes: 6 additions & 7 deletions src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,14 @@ void set_text_callback(struct transcription_filter_data *gf,
}

// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
if (!gf->filter_words_replace.empty()) {
const std::string original_str_copy = str_copy;
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
// if suppress_sentence exists within str_copy, remove it (replace with "")
str_copy = std::regex_replace(str_copy, std::regex(suppress_sentence), "");
for (const auto &filter_words : gf->filter_words_replace) {
// if filter exists within str_copy, replace it with the replacement
str_copy = std::regex_replace(str_copy,
std::regex(std::get<0>(filter_words)),
std::get<1>(filter_words));
}
// if the text was modified, log the original and modified text
if (original_str_copy != str_copy) {
Expand Down
2 changes: 1 addition & 1 deletion src/transcription-filter-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct transcription_filter_data {
std::string target_lang;
std::string translation_output;
bool enable_token_ts_dtw = false;
std::string suppress_sentences;
std::vector<std::tuple<std::string, std::string>> filter_words_replace;
bool fix_utf8 = true;
bool enable_audio_chunks_callback = false;
bool source_signals_set = false;
Expand Down
31 changes: 25 additions & 6 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "translation/translation-utils.h"
#include "translation/translation.h"
#include "translation/translation-includes.h"
#include "ui/filter-replace-dialog.h"

void set_source_signals(transcription_filter_data *gf, obs_source_t *parent_source)
{
Expand Down Expand Up @@ -190,6 +191,8 @@ void transcription_filter_update(void *data, obs_data_t *s)
int new_buffer_num_chars_per_line = (int)obs_data_get_int(s, "buffer_num_chars_per_line");
TokenBufferSegmentation new_buffer_output_type =
(TokenBufferSegmentation)obs_data_get_int(s, "buffer_output_type");
gf->filter_words_replace =
deserialize_filter_words_replace(obs_data_get_string(s, "filter_words_replace"));

if (new_buffered_output) {
obs_log(gf->log_level, "buffered_output enable");
Expand Down Expand Up @@ -247,7 +250,6 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->translation_ctx.input_tokenization_style =
(InputTokenizationStyle)obs_data_get_int(s, "translate_input_tokenization_style");
gf->translation_output = obs_data_get_string(s, "translate_output");
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
std::string new_translate_model_index = obs_data_get_string(s, "translate_model");
std::string new_translation_model_path_external =
obs_data_get_string(s, "translation_model_path_external");
Expand Down Expand Up @@ -554,7 +556,6 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "translate_model", "whisper-based-translation");
obs_data_set_default_string(s, "translation_model_path_external", "");
obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100);
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4);

// translation options
Expand Down Expand Up @@ -841,7 +842,7 @@ obs_properties_t *transcription_filter_properties(void *data)
{"whisper_params_group", "log_words", "caption_to_stream", "buffer_size_msec",
"overlap_size_msec", "step_by_step_processing", "min_sub_duration",
"process_while_muted", "buffered_output", "vad_enabled", "log_level",
"suppress_sentences", "sentence_psum_accept_thresh", "vad_threshold",
"open_filter_ui", "sentence_psum_accept_thresh", "vad_threshold",
"buffered_output_group"}) {
obs_property_set_visible(obs_properties_get(props, prop_name.c_str()),
show_hide);
Expand Down Expand Up @@ -902,9 +903,27 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_int(list, "INFO", LOG_INFO);
obs_property_list_add_int(list, "WARNING", LOG_WARNING);

// add a text input for sentences to suppress
obs_properties_add_text(ppts, "suppress_sentences", MT_("suppress_sentences"),
OBS_TEXT_MULTILINE);
// add button to open filter and replace UI dialog
obs_properties_add_button2(
ppts, "open_filter_ui", MT_("open_filter_ui"),
[](obs_properties_t *props, obs_property_t *property, void *data_) {
UNUSED_PARAMETER(props);
UNUSED_PARAMETER(property);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
FilterReplaceDialog *filter_replace_dialog = new FilterReplaceDialog(
(QWidget *)obs_frontend_get_main_window(), gf_);
filter_replace_dialog->exec();
// store the filter data on the source settings
obs_data_t *settings = obs_source_get_settings(gf_->context);
// serialize the filter data
const std::string filter_data =
serialize_filter_words_replace(gf_->filter_words_replace);
obs_data_set_string(settings, "filter_words_replace", filter_data.c_str());
obs_data_release(settings);
return true;
},
gf);

obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
Expand Down
104 changes: 104 additions & 0 deletions src/ui/filter-replace-dialog.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#include "filter-replace-dialog.h"
#include "ui_filter-replace-dialog.h"

FilterReplaceDialog::FilterReplaceDialog(QWidget *parent, transcription_filter_data *ctx_)
: QDialog(parent),
ctx(ctx_),
ui(new Ui::FilterReplaceDialog)
{
ui->setupUi(this);

// populate the tableWidget with the filter_words_replace map
ui->tableWidget->setRowCount((int)ctx->filter_words_replace.size());
for (size_t i = 0; i < ctx->filter_words_replace.size(); i++) {
const std::string key = std::get<0>(ctx->filter_words_replace[i]);
const std::string value = std::get<1>(ctx->filter_words_replace[i]);
ui->tableWidget->setItem((int)i, 0,
new QTableWidgetItem(QString::fromStdString(key)));
ui->tableWidget->setItem((int)i, 1,
new QTableWidgetItem(QString::fromStdString(value)));
}

// connect toolButton_add
connect(ui->toolButton_add, &QToolButton::clicked, this, &FilterReplaceDialog::addFilter);
// connect toolButton_remove
connect(ui->toolButton_remove, &QToolButton::clicked, this,
&FilterReplaceDialog::removeFilter);
// connect edit triggers
connect(ui->tableWidget, &QTableWidget::itemChanged, this,
&FilterReplaceDialog::editFilter);
}

FilterReplaceDialog::~FilterReplaceDialog()
{
delete ui;
}

void FilterReplaceDialog::addFilter()
{
ui->tableWidget->insertRow(ui->tableWidget->rowCount());
// add an empty filter_words_replace map entry
ctx->filter_words_replace.push_back(std::make_tuple("", ""));
}

void FilterReplaceDialog::removeFilter()
{
if (ui->tableWidget->currentRow() == -1) {
return;
}
ui->tableWidget->removeRow(ui->tableWidget->currentRow());
// remove the filter_words_replace map entry
ctx->filter_words_replace.erase(ctx->filter_words_replace.begin() +
ui->tableWidget->currentRow() + 1);
}

void FilterReplaceDialog::editFilter(QTableWidgetItem *item)
{
if (item->row() >= (int)ctx->filter_words_replace.size()) {
return;
}

std::string key;
if (ui->tableWidget->item(item->row(), 0) == nullptr) {
key = "";
} else {
key = ui->tableWidget->item(item->row(), 0)->text().toStdString();
}
std::string value;
if (ui->tableWidget->item(item->row(), 1) == nullptr) {
value = "";
} else {
value = ui->tableWidget->item(item->row(), 1)->text().toStdString();
}
// use the row number to update the filter_words_replace map
ctx->filter_words_replace[item->row()] = std::make_tuple(key, value);
}

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace)
{
if (filter_words_replace.empty()) {
return "[]";
}
// use JSON to serialize the filter_words_replace map
nlohmann::json j;
for (const auto &entry : filter_words_replace) {
j.push_back({{"key", std::get<0>(entry)}, {"value", std::get<1>(entry)}});
}
return j.dump();
}

std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str)
{
if (filter_words_replace_str.empty()) {
return {};
}
// use JSON to deserialize the filter_words_replace map
std::vector<std::tuple<std::string, std::string>> filter_words_replace;
nlohmann::json j = nlohmann::json::parse(filter_words_replace_str);
for (const auto &entry : j) {
filter_words_replace.push_back(std::make_tuple(entry["key"], entry["value"]));
}
return filter_words_replace;
}
35 changes: 35 additions & 0 deletions src/ui/filter-replace-dialog.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef FILTERREPLACEDIALOG_H
#define FILTERREPLACEDIALOG_H

#include <QDialog>
#include <QTableWidgetItem>

#include "transcription-filter-data.h"

namespace Ui {
class FilterReplaceDialog;
}

class FilterReplaceDialog : public QDialog {
Q_OBJECT

public:
explicit FilterReplaceDialog(QWidget *parent, transcription_filter_data *ctx_);
~FilterReplaceDialog();

private:
Ui::FilterReplaceDialog *ui;
transcription_filter_data *ctx;

private slots:
void addFilter();
void removeFilter();
void editFilter(QTableWidgetItem *item);
};

std::string serialize_filter_words_replace(
const std::vector<std::tuple<std::string, std::string>> &filter_words_replace);
std::vector<std::tuple<std::string, std::string>>
deserialize_filter_words_replace(const std::string &filter_words_replace_str);

#endif // FILTERREPLACEDIALOG_H
Loading
Loading