Skip to content

Commit c302d3a

Browse files
committed
feat: Add translation language utilities
This commit adds a new file, `translation-language-utils.h`, which contains utility functions for handling translation languages. The `remove_start_punctuation` function removes any leading punctuation from a given string. This utility will be used in the translation process to improve the quality of the translated output.
1 parent 0e3df02 commit c302d3a

10 files changed

+190
-8
lines changed

CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ else()
9696
include(cmake/FetchOnnxruntime.cmake)
9797
endif()
9898

99+
include(cmake/BuildICU.cmake)
100+
# Add ICU to the target
101+
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU::ICU)
102+
target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR})
103+
99104
target_sources(
100105
${CMAKE_PROJECT_NAME}
101106
PRIVATE src/plugin-main.c
@@ -118,6 +123,7 @@ target_sources(
118123
src/translation/language_codes.cpp
119124
src/translation/translation.cpp
120125
src/translation/translation-utils.cpp
126+
src/translation/translation-language-utils.cpp
121127
src/ui/filter-replace-dialog.cpp)
122128

123129
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
@@ -140,11 +146,12 @@ if(ENABLE_TESTS)
140146
src/whisper-utils/token-buffer-thread.cpp
141147
src/whisper-utils/vad-processing.cpp
142148
src/translation/language_codes.cpp
143-
src/translation/translation.cpp)
149+
src/translation/translation.cpp
150+
src/translation/translation-language-utils.cpp)
144151

145152
find_libav(${CMAKE_PROJECT_NAME}-tests)
146153

147-
target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs)
154+
target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs ICU::ICU)
148155
target_include_directories(${CMAKE_PROJECT_NAME}-tests PRIVATE src)
149156

150157
# install the tests to the release/test directory

cmake/BuildICU.cmake

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
include(FetchContent)
2+
include(ExternalProject)
3+
4+
set(ICU_VERSION "75.1")
5+
set(ICU_VERSION_UNDERSCORE "75_1")
6+
set(ICU_VERSION_DASH "75-1")
7+
set(ICU_VERSION_NO_MINOR "75")
8+
9+
if(WIN32)
10+
set(ICU_URL
11+
"https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-Win64-MSVC2022.zip"
12+
)
13+
set(ICU_HASH "SHA256=7ac9c0dc6ccc1ec809c7d5689b8d831c5b8f6b11ecf70fdccc55f7ae8731ac8f")
14+
15+
FetchContent_Declare(
16+
ICU
17+
URL ${ICU_URL}
18+
URL_HASH ${ICU_HASH})
19+
20+
FetchContent_MakeAvailable(ICU)
21+
22+
# Assuming the ZIP structure, adjust paths as necessary
23+
set(ICU_INCLUDE_DIR "${icu_SOURCE_DIR}/include")
24+
set(ICU_LIBRARY_DIR "${icu_SOURCE_DIR}/lib64")
25+
set(ICU_BINARY_DIR "${icu_SOURCE_DIR}/bin64")
26+
27+
# Add ICU libraries
28+
find_library(
29+
ICU_DATA_LIBRARY
30+
NAMES icudt
31+
PATHS ${ICU_LIBRARY_DIR}
32+
NO_DEFAULT_PATH)
33+
find_library(
34+
ICU_UC_LIBRARY
35+
NAMES icuuc
36+
PATHS ${ICU_LIBRARY_DIR}
37+
NO_DEFAULT_PATH)
38+
find_library(
39+
ICU_IN_LIBRARY
40+
NAMES icuin
41+
PATHS ${ICU_LIBRARY_DIR}
42+
NO_DEFAULT_PATH)
43+
44+
# find the dlls
45+
find_file(
46+
ICU_DATA_DLL
47+
NAMES icudt${ICU_VERSION_NO_MINOR}.dll
48+
PATHS ${ICU_BINARY_DIR}
49+
NO_DEFAULT_PATH)
50+
find_file(
51+
ICU_UC_DLL
52+
NAMES icuuc${ICU_VERSION_NO_MINOR}.dll
53+
PATHS ${ICU_BINARY_DIR}
54+
NO_DEFAULT_PATH)
55+
find_file(
56+
ICU_IN_DLL
57+
NAMES icuin${ICU_VERSION_NO_MINOR}.dll
58+
PATHS ${ICU_BINARY_DIR}
59+
NO_DEFAULT_PATH)
60+
61+
# Copy the DLLs to the output directory
62+
install(FILES ${ICU_DATA_DLL} DESTINATION "obs-plugins/64bit")
63+
install(FILES ${ICU_UC_DLL} DESTINATION "obs-plugins/64bit")
64+
install(FILES ${ICU_IN_DLL} DESTINATION "obs-plugins/64bit")
65+
66+
else() # Mac and Linux
67+
set(ICU_URL
68+
"https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_UNDERSCORE}/icu4c-${ICU_VERSION_UNDERSCORE}-src.tgz"
69+
)
70+
set(ICU_HASH "SHA256=94bb97d88f13bb74ec0168446a845511bd92c1c49ee8e63df646a48c38dfde6d")
71+
72+
set(ICU_INSTALL_DIR "${CMAKE_BINARY_DIR}/icu-install")
73+
74+
ExternalProject_Add(
75+
ICU
76+
URL ${ICU_URL}
77+
URL_HASH ${ICU_HASH}
78+
CONFIGURE_COMMAND <SOURCE_DIR>/source/runConfigureICU Linux --prefix=${ICU_INSTALL_DIR}
79+
BUILD_COMMAND make -j4
80+
INSTALL_COMMAND make install
81+
BUILD_IN_SOURCE 1)
82+
83+
set(ICU_INCLUDE_DIR "${ICU_INSTALL_DIR}/include")
84+
set(ICU_LIBRARY_DIR "${ICU_INSTALL_DIR}/lib")
85+
86+
# Add ICU libraries
87+
find_library(
88+
ICU_DATA_LIBRARY
89+
NAMES icudata
90+
PATHS ${ICU_LIBRARY_DIR}
91+
NO_DEFAULT_PATH)
92+
find_library(
93+
ICU_UC_LIBRARY
94+
NAMES icuuc
95+
PATHS ${ICU_LIBRARY_DIR}
96+
NO_DEFAULT_PATH)
97+
find_library(
98+
ICU_I18N_LIBRARY
99+
NAMES icui18n
100+
PATHS ${ICU_LIBRARY_DIR}
101+
NO_DEFAULT_PATH)
102+
endif()
103+
104+
# Create an interface target for ICU
105+
add_library(ICU::ICU INTERFACE IMPORTED GLOBAL)
106+
target_include_directories(ICU::ICU INTERFACE ${ICU_INCLUDE_DIR})
107+
target_link_libraries(ICU::ICU INTERFACE ${ICU_DATA_LIBRARY} ${ICU_UC_LIBRARY}
108+
$<IF:$<BOOL:${WIN32}>,${ICU_IN_LIBRARY},${ICU_I18N_LIBRARY}>)
109+
110+
if(NOT WIN32)
111+
add_dependencies(ICU::ICU ICU)
112+
endif()

src/transcription-filter-callbacks.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,11 @@ void set_text_callback(struct transcription_filter_data *gf,
281281
gf->cleared_last_sub = false;
282282
if (result.result == DETECTION_RESULT_SPEECH) {
283283
// save the last subtitle if it was a full sentence
284-
gf->last_transcription_sentence = result.text;
284+
gf->last_transcription_sentence.push_back(result.text);
285+
// remove the oldest sentence if the buffer is too long
286+
while (gf->last_transcription_sentence.size() > gf->n_context_sentences) {
287+
gf->last_transcription_sentence.pop_front();
288+
}
285289
}
286290
}
287291
};
@@ -330,7 +334,7 @@ void reset_caption_state(transcription_filter_data *gf_)
330334
gf_->last_text_translation = "";
331335
gf_->translation_ctx.last_input_tokens.clear();
332336
gf_->translation_ctx.last_translation_tokens.clear();
333-
gf_->last_transcription_sentence = "";
337+
gf_->last_transcription_sentence.clear();
334338
// flush the buffer
335339
{
336340
std::lock_guard<std::mutex> lock(gf_->whisper_buf_mutex);

src/transcription-filter-data.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ struct transcription_filter_data {
9393
std::string last_text_for_translation;
9494
std::string last_text_translation;
9595

96-
std::string last_transcription_sentence;
96+
// Transcription context sentences
97+
int n_context_sentences;
98+
std::deque<std::string> last_transcription_sentence;
9799

98100
// Text source to output the subtitles
99101
std::string text_source_name;

src/transcription-filter-properties.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ void add_whisper_params_group_properties(obs_properties_t *ppts)
414414
WHISPER_SAMPLING_BEAM_SEARCH);
415415
obs_property_list_add_int(whisper_sampling_method_list, "Greedy", WHISPER_SAMPLING_GREEDY);
416416

417+
// add int slider for context sentences
418+
obs_properties_add_int_slider(whisper_params_group, "n_context_sentences",
419+
MT_("n_context_sentences"), 0, 5, 1);
420+
417421
// int n_threads;
418422
obs_properties_add_int_slider(whisper_params_group, "n_threads", MT_("n_threads"), 1, 8, 1);
419423
// int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
@@ -600,6 +604,7 @@ void transcription_filter_defaults(obs_data_t *s)
600604

601605
// Whisper parameters
602606
obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH);
607+
obs_data_set_default_int(s, "n_context_sentences", 0);
603608
obs_data_set_default_string(s, "initial_prompt", "");
604609
obs_data_set_default_int(s, "n_threads", 4);
605610
obs_data_set_default_int(s, "n_max_text_ctx", 16384);

src/transcription-filter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ void transcription_filter_update(void *data, obs_data_t *s)
346346
{
347347
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
348348

349+
gf->n_context_sentences = (int)obs_data_get_int(s, "n_context_sentences");
350+
349351
gf->sentence_psum_accept_thresh =
350352
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");
351353

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include "translation-language-utils.h"
2+
3+
#include <unicode/unistr.h>
4+
#include <unicode/uchar.h>
5+
6+
std::string remove_start_punctuation(const std::string &text)
7+
{
8+
if (text.empty()) {
9+
return text;
10+
}
11+
12+
// Convert the input string to ICU's UnicodeString
13+
icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(text);
14+
15+
// Find the index of the first non-punctuation character
16+
int32_t start = 0;
17+
while (start < ustr.length()) {
18+
UChar32 ch = ustr.char32At(start);
19+
if (!u_ispunct(ch)) {
20+
break;
21+
}
22+
start += U16_LENGTH(ch);
23+
}
24+
25+
// Create a new UnicodeString with punctuation removed from the start
26+
icu::UnicodeString result = ustr.tempSubString(start);
27+
28+
// Convert the result back to UTF-8
29+
std::string output;
30+
result.toUTF8String(output);
31+
32+
return output;
33+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef TRANSLATION_LANGUAGE_UTILS_H
2+
#define TRANSLATION_LANGUAGE_UTILS_H
3+
4+
#include <string>
5+
6+
std::string remove_start_punctuation(const std::string &text);
7+
8+
#endif // TRANSLATION_LANGUAGE_UTILS_H

src/translation/translation.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "model-utils/model-find-utils.h"
44
#include "transcription-filter-data.h"
55
#include "language_codes.h"
6+
#include "translation-language-utils.h"
67

78
#include <ctranslate2/translator.h>
89
#include <sentencepiece_processor.h>
@@ -201,7 +202,8 @@ int translate(struct translation_context &translation_ctx, const std::string &te
201202
(int)translation_ctx.last_translation_tokens.size());
202203

203204
// detokenize
204-
result = translation_ctx.detokenizer(translation_tokens);
205+
const std::string result_ = translation_ctx.detokenizer(translation_tokens);
206+
result = remove_start_punctuation(result_);
205207
} catch (std::exception &e) {
206208
obs_log(LOG_ERROR, "Error: %s", e.what());
207209
return OBS_POLYGLOT_TRANSLATION_FAIL;

src/whisper-utils/whisper-processing.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,15 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
181181
return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""};
182182
}
183183

184-
// obs_log(LOG_INFO, "initial prompt: %s", gf->last_transcription_sentence.c_str());
185-
// gf->whisper_params.initial_prompt = gf->last_transcription_sentence.c_str();
184+
if (gf->n_context_sentences > 0 && !gf->last_transcription_sentence.empty()) {
185+
// set the initial prompt to the last transcription sentences (concatenated)
186+
std::string initial_prompt = gf->last_transcription_sentence[0];
187+
for (int i = 1; i < gf->last_transcription_sentence.size(); ++i) {
188+
initial_prompt += " " + gf->last_transcription_sentence[i];
189+
}
190+
gf->whisper_params.initial_prompt = initial_prompt.c_str();
191+
obs_log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt);
192+
}
186193

187194
// run the inference
188195
int whisper_full_result = -1;

0 commit comments

Comments
 (0)