From f14d9ff1a0fd8fa1c18ea3bf86f9046e3d8c0f3f Mon Sep 17 00:00:00 2001 From: Quinn McNamara Date: Thu, 17 Feb 2022 00:47:27 -0600 Subject: [PATCH] Better handling of case differences for non-ascii inputs (#24) * Starting with tests with capitalized non-ascii characters * Replacing all tolower calls with ICU implementation * Fixing broken build with utility changes --- CMakeLists.txt | 7 ++++++- Dockerfile | 7 ++++--- src/Ctm.cpp | 11 +++++------ src/Nlp.cpp | 13 +++++++------ src/OneBestFstLoader.cpp | 6 ++---- src/fstalign.cpp | 6 ++---- src/utilities.cpp | 15 +++++++++++++++ src/utilities.h | 14 ++++++-------- test/data/wer_utf.hyp.txt | 1 + test/data/wer_utf.ref.txt | 1 + test/fstalign_Test.cc | 8 ++++++++ 11 files changed, 57 insertions(+), 32 deletions(-) create mode 100644 test/data/wer_utf.hyp.txt create mode 100644 test/data/wer_utf.ref.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index eae9329..3048cc5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.7) project(fstalign LANGUAGES CXX C) @@ -69,10 +69,14 @@ add_library(fstaligner-common third-party/inih/cpp/INIReader.cpp ) +list(APPEND CMAKE_PREFIX_PATH "/usr/local/opt/icu4c") # for Mac users +find_package(ICU REQUIRED COMPONENTS uc) + target_link_libraries(fstaligner-common Threads::Threads ${FSTALIGN_LIBRARIES} ${FST_KALDI_LIBRARIES} + ${ICU_LIBRARIES} ) add_subdirectory(third-party/jsoncpp) @@ -83,6 +87,7 @@ add_executable(fstalign src/main.cpp) include_directories(fstalign ${FSTALIGN_INCLUDES} ${OPENFST_INCLUDES} + ${ICU_INCLUDE_DIRS} ) target_link_libraries(fstalign diff --git a/Dockerfile b/Dockerfile index 7933cfc..c40e893 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,8 @@ ARG JOBS=4 RUN apt-get update && \ apt-get -y install \ cmake \ - g++ + g++ \ + libicu-dev RUN mkdir /fstalign COPY CMakeLists.txt /fstalign/CMakeLists.txt @@ -34,5 +35,5 @@ RUN mkdir -p /fstalign/build && \ COPY tools /fstalign/tools ENV PATH \ -/fstalign/bin/:\ -$PATH + /fstalign/bin/:\ + $PATH diff --git a/src/Ctm.cpp b/src/Ctm.cpp index fd12457..6353b19 100644 --- a/src/Ctm.cpp +++ b/src/Ctm.cpp @@ -21,9 +21,8 @@ CtmFstLoader::CtmFstLoader(vector &records) : FstLoader() { { mCtmRows = records; for (auto &row : mCtmRows) { - string token = string(row.word); - std::transform(token.begin(), token.end(), token.begin(), ::tolower); - mToken.push_back(token); + std::string lower_cased = UnicodeLowercase(row.word); + mToken.push_back(lower_cased); } } } @@ -52,13 +51,13 @@ StdVectorFst CtmFstLoader::convertToFst(const SymbolTable &symbol, std::vector wc && map[wc] > 0) { - transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 1.0f, nextState)); + transducer.AddArc(prevState, StdArc(symbol.Find(lower_cased), symbol.Find(lower_cased), 1.0f, nextState)); } else { - transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 0.0f, nextState)); + transducer.AddArc(prevState, StdArc(symbol.Find(lower_cased), symbol.Find(lower_cased), 0.0f, nextState)); } prevState = nextState; diff --git a/src/Nlp.cpp b/src/Nlp.cpp index b120e64..c6fc40e 100644 --- a/src/Nlp.cpp +++ b/src/Nlp.cpp @@ -10,6 +10,7 @@ #include #include +#include "utilities.h" /*********************************** NLP FstLoader class start @@ -64,8 +65,8 @@ NlpFstLoader::NlpFstLoader(std::vector &records, Json::Value norma mJsonNorm[curr_label_id]["candidates"][last_idx]["verbalization"].append(curr_tk); } } else { - std::transform(curr_tk.begin(), curr_tk.end(), curr_tk.begin(), ::tolower); - mToken.push_back(curr_tk); + std::string lower_cased = UnicodeLowercase(curr_tk); + mToken.push_back(lower_cased); mSpeakers.push_back(speaker); } @@ -93,8 +94,8 @@ void NlpFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const { auto candidate = candidates[i]["verbalization"]; for (auto tk_itr : candidate) { std::string token = tk_itr.asString(); - std::transform(token.begin(), token.end(), token.begin(), ::tolower); - AddSymbolIfNeeded(symbol, token); + std::string lower_cased = UnicodeLowercase(token); + AddSymbolIfNeeded(symbol, lower_cased); } } } @@ -225,11 +226,11 @@ so we add 2 states auto candidate = candidates[i]["verbalization"]; for (auto tk_itr : candidate) { std::string ltoken = std::string(tk_itr.asString()); - std::transform(ltoken.begin(), ltoken.end(), ltoken.begin(), ::tolower); + std::string lower_cased = UnicodeLowercase(ltoken); transducer.AddState(); nextState++; - int token_sym = symbol.Find(ltoken); + int token_sym = symbol.Find(lower_cased); if (token_sym == -1) { token_sym = symbol.Find(options.symUnk); } diff --git a/src/OneBestFstLoader.cpp b/src/OneBestFstLoader.cpp index 87c57f3..0e6cc75 100644 --- a/src/OneBestFstLoader.cpp +++ b/src/OneBestFstLoader.cpp @@ -33,8 +33,7 @@ void OneBestFstLoader::LoadTextFile(const std::string filename) { void OneBestFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const { for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) { - std::string token = *i; - std::transform(token.begin(), token.end(), token.begin(), ::tolower); + std::string token = UnicodeLowercase(*i); // fst::kNoSymbol if (symbol.Find(token) == -1) { symbol.AddSymbol(token); @@ -58,8 +57,7 @@ fst::StdVectorFst OneBestFstLoader::convertToFst(const fst::SymbolTable &symbol, int map_sz = map.size(); int wc = 0; for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) { - std::string token = *i; - std::transform(token.begin(), token.end(), token.begin(), ::tolower); + std::string token = UnicodeLowercase(*i); transducer.AddState(); int tk_idx = symbol.Find(token); diff --git a/src/fstalign.cpp b/src/fstalign.cpp index 5747b3d..042ef19 100644 --- a/src/fstalign.cpp +++ b/src/fstalign.cpp @@ -289,8 +289,7 @@ vector> make_stitches(spWERA alignment, vectorhyp_orig = ctmPart.word; // sanity check - std::string ctmCopy = ctmPart.word; - std::transform(ctmCopy.begin(), ctmCopy.end(), ctmCopy.begin(), ::tolower); + std::string ctmCopy = UnicodeLowercase(ctmPart.word); if (hyp_tk != ctmCopy) { logger->warn( "hum, looks like the ctm and the alignment got out of sync? [{}] vs " @@ -329,8 +328,7 @@ vector> make_stitches(spWERA alignment, vectorhyp_orig = token; // sanity check - std::string token_copy = token; - std::transform(token_copy.begin(), token_copy.end(), token_copy.begin(), ::tolower); + std::string token_copy = UnicodeLowercase(token); if (hyp_tk != token_copy) { logger->warn( "hum, looks like the text and the alignment got out of sync? [{}] vs " diff --git a/src/utilities.cpp b/src/utilities.cpp index 775d2a0..e7d6608 100644 --- a/src/utilities.cpp +++ b/src/utilities.cpp @@ -233,3 +233,18 @@ string GetClassLabel(string best_label) { std::replace(classlabel.begin(), classlabel.end(), ':', '_'); return classlabel; } + +string UnicodeLowercase(string token) { + icu::UnicodeString utoken = icu::UnicodeString::fromUTF8(token); + std::string lower_cased; + utoken.toLower().toUTF8String(lower_cased); + return lower_cased; +} + +bool EndsWithCaseInsensitive(const string &value, const string &ending) { + if (ending.size() > value.size()) { + return false; + } + return equal(ending.rbegin(), ending.rend(), value.rbegin(), + [](const char a, const char b) { return tolower(a) == tolower(b); }); +} diff --git a/src/utilities.h b/src/utilities.h index 0b558f5..a7dd9c4 100644 --- a/src/utilities.h +++ b/src/utilities.h @@ -8,6 +8,9 @@ #ifndef UTILITIES_H_ #define UTILITIES_H_ +#include +#include +#include #include #include #include @@ -208,14 +211,7 @@ void printFst(string loggerName, const fst::StdFst *fst, const fst::SymbolTable template void splitString(const string &str, char delimiter, StringFunction f); -static bool EndsWithCaseInsensitive(const string &value, const string &ending) { - if (ending.size() > value.size()) { - return false; - } - return equal(ending.rbegin(), ending.rend(), value.rbegin(), - [](const char a, const char b) { return tolower(a) == tolower(b); }); -} - +bool EndsWithCaseInsensitive(const string &value, const string &ending); bool iequals(const std::string &, const std::string &); // string manip @@ -242,4 +238,6 @@ string GetLabelNameFromClassLabel(string classLabel); string GetClassLabel(string best_label); +string UnicodeLowercase(string token); + #endif // UTILITIES_H_ diff --git a/test/data/wer_utf.hyp.txt b/test/data/wer_utf.hyp.txt new file mode 100644 index 0000000..6e75652 --- /dev/null +++ b/test/data/wer_utf.hyp.txt @@ -0,0 +1 @@ +ça va va bien aujourd'hui éte inutile Êtes \ No newline at end of file diff --git a/test/data/wer_utf.ref.txt b/test/data/wer_utf.ref.txt new file mode 100644 index 0000000..6d6b81b --- /dev/null +++ b/test/data/wer_utf.ref.txt @@ -0,0 +1 @@ +Ça va bien aujourd'hui étÉ inutile êtes \ No newline at end of file diff --git a/test/fstalign_Test.cc b/test/fstalign_Test.cc index 2adb44c..f2f713d 100644 --- a/test/fstalign_Test.cc +++ b/test/fstalign_Test.cc @@ -767,6 +767,14 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") { REQUIRE_THAT(result, Contains("Wer Entity ID 3 WER: 1/3 = 0.3333")); } + SECTION("wer_utf wer") { + const auto result = + exec(command("wer", approach, "wer_utf.ref.txt", "wer_utf.hyp.txt", sbs_output, "", TEST_SYNONYMS)); + + REQUIRE_THAT(result, Contains("WER: 2/7 = 0.2857")); + REQUIRE_THAT(result, Contains("WER: INS:1 DEL:0 SUB:1")); + } + // Additional WER tests SECTION("entity precision recall") { const auto testFile = std::string{TEST_DATA} + "twenty.hyp-a2.sbs";