Skip to content

Commit

Permalink
Better handling of case differences for non-ascii inputs (#24)
Browse files Browse the repository at this point in the history
* Starting with tests with capitalized non-ascii characters

* Replacing all tolower calls with ICU implementation

* Fixing broken build with utility changes
  • Loading branch information
qmac authored Feb 17, 2022
1 parent 1c9fd9e commit f14d9ff
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 32 deletions.
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.5)
cmake_minimum_required(VERSION 3.7)

project(fstalign LANGUAGES CXX C)

Expand Down Expand Up @@ -69,10 +69,14 @@ add_library(fstaligner-common
third-party/inih/cpp/INIReader.cpp
)

list(APPEND CMAKE_PREFIX_PATH "/usr/local/opt/icu4c") # for Mac users
find_package(ICU REQUIRED COMPONENTS uc)

target_link_libraries(fstaligner-common
Threads::Threads
${FSTALIGN_LIBRARIES}
${FST_KALDI_LIBRARIES}
${ICU_LIBRARIES}
)

add_subdirectory(third-party/jsoncpp)
Expand All @@ -83,6 +87,7 @@ add_executable(fstalign src/main.cpp)
include_directories(fstalign
${FSTALIGN_INCLUDES}
${OPENFST_INCLUDES}
${ICU_INCLUDE_DIRS}
)

target_link_libraries(fstalign
Expand Down
7 changes: 4 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ ARG JOBS=4
RUN apt-get update && \
apt-get -y install \
cmake \
g++
g++ \
libicu-dev

RUN mkdir /fstalign
COPY CMakeLists.txt /fstalign/CMakeLists.txt
Expand All @@ -34,5 +35,5 @@ RUN mkdir -p /fstalign/build && \
COPY tools /fstalign/tools

ENV PATH \
/fstalign/bin/:\
$PATH
/fstalign/bin/:\
$PATH
11 changes: 5 additions & 6 deletions src/Ctm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ CtmFstLoader::CtmFstLoader(vector<RawCtmRecord> &records) : FstLoader() {
{
mCtmRows = records;
for (auto &row : mCtmRows) {
string token = string(row.word);
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
mToken.push_back(token);
std::string lower_cased = UnicodeLowercase(row.word);
mToken.push_back(lower_cased);
}
}
}
Expand Down Expand Up @@ -52,13 +51,13 @@ StdVectorFst CtmFstLoader::convertToFst(const SymbolTable &symbol, std::vector<i
int map_sz = map.size();
for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) {
std::string token = *i;
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
std::string lower_cased = UnicodeLowercase(token);
transducer.AddState();

if (map_sz > wc && map[wc] > 0) {
transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 1.0f, nextState));
transducer.AddArc(prevState, StdArc(symbol.Find(lower_cased), symbol.Find(lower_cased), 1.0f, nextState));
} else {
transducer.AddArc(prevState, StdArc(symbol.Find(token), symbol.Find(token), 0.0f, nextState));
transducer.AddArc(prevState, StdArc(symbol.Find(lower_cased), symbol.Find(lower_cased), 0.0f, nextState));
}

prevState = nextState;
Expand Down
13 changes: 7 additions & 6 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <fstream>

#include <csv/csv.h>
#include "utilities.h"

/***********************************
NLP FstLoader class start
Expand Down Expand Up @@ -64,8 +65,8 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
mJsonNorm[curr_label_id]["candidates"][last_idx]["verbalization"].append(curr_tk);
}
} else {
std::transform(curr_tk.begin(), curr_tk.end(), curr_tk.begin(), ::tolower);
mToken.push_back(curr_tk);
std::string lower_cased = UnicodeLowercase(curr_tk);
mToken.push_back(lower_cased);
mSpeakers.push_back(speaker);
}

Expand Down Expand Up @@ -93,8 +94,8 @@ void NlpFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const {
auto candidate = candidates[i]["verbalization"];
for (auto tk_itr : candidate) {
std::string token = tk_itr.asString();
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
AddSymbolIfNeeded(symbol, token);
std::string lower_cased = UnicodeLowercase(token);
AddSymbolIfNeeded(symbol, lower_cased);
}
}
}
Expand Down Expand Up @@ -225,11 +226,11 @@ so we add 2 states
auto candidate = candidates[i]["verbalization"];
for (auto tk_itr : candidate) {
std::string ltoken = std::string(tk_itr.asString());
std::transform(ltoken.begin(), ltoken.end(), ltoken.begin(), ::tolower);
std::string lower_cased = UnicodeLowercase(ltoken);
transducer.AddState();
nextState++;

int token_sym = symbol.Find(ltoken);
int token_sym = symbol.Find(lower_cased);
if (token_sym == -1) {
token_sym = symbol.Find(options.symUnk);
}
Expand Down
6 changes: 2 additions & 4 deletions src/OneBestFstLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ void OneBestFstLoader::LoadTextFile(const std::string filename) {

void OneBestFstLoader::addToSymbolTable(fst::SymbolTable &symbol) const {
for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) {
std::string token = *i;
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
std::string token = UnicodeLowercase(*i);
// fst::kNoSymbol
if (symbol.Find(token) == -1) {
symbol.AddSymbol(token);
Expand All @@ -58,8 +57,7 @@ fst::StdVectorFst OneBestFstLoader::convertToFst(const fst::SymbolTable &symbol,
int map_sz = map.size();
int wc = 0;
for (TokenType::const_iterator i = mToken.begin(); i != mToken.end(); ++i) {
std::string token = *i;
std::transform(token.begin(), token.end(), token.begin(), ::tolower);
std::string token = UnicodeLowercase(*i);
transducer.AddState();

int tk_idx = symbol.Find(token);
Expand Down
6 changes: 2 additions & 4 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ vector<shared_ptr<Stitching>> make_stitches(spWERA alignment, vector<RawCtmRecor

part->hyp_orig = ctmPart.word;
// sanity check
std::string ctmCopy = ctmPart.word;
std::transform(ctmCopy.begin(), ctmCopy.end(), ctmCopy.begin(), ::tolower);
std::string ctmCopy = UnicodeLowercase(ctmPart.word);
if (hyp_tk != ctmCopy) {
logger->warn(
"hum, looks like the ctm and the alignment got out of sync? [{}] vs "
Expand Down Expand Up @@ -329,8 +328,7 @@ vector<shared_ptr<Stitching>> make_stitches(spWERA alignment, vector<RawCtmRecor
part->hyp_orig = token;

// sanity check
std::string token_copy = token;
std::transform(token_copy.begin(), token_copy.end(), token_copy.begin(), ::tolower);
std::string token_copy = UnicodeLowercase(token);
if (hyp_tk != token_copy) {
logger->warn(
"hum, looks like the text and the alignment got out of sync? [{}] vs "
Expand Down
15 changes: 15 additions & 0 deletions src/utilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,18 @@ string GetClassLabel(string best_label) {
std::replace(classlabel.begin(), classlabel.end(), ':', '_');
return classlabel;
}

string UnicodeLowercase(string token) {
icu::UnicodeString utoken = icu::UnicodeString::fromUTF8(token);
std::string lower_cased;
utoken.toLower().toUTF8String(lower_cased);
return lower_cased;
}

// Reports whether `value` ends with `ending`, comparing the two strings one
// byte at a time after folding each byte with std::tolower.
//
// value:   string whose tail is inspected.
// ending:  suffix to look for; an empty suffix always matches.
// returns: true when the last ending.size() bytes of `value` match `ending`
//          ignoring case, false otherwise (including when `ending` is longer
//          than `value`).
bool EndsWithCaseInsensitive(const std::string &value, const std::string &ending) {
  if (ending.size() > value.size()) {
    return false;
  }
  // std::tolower is undefined for negative arguments other than EOF, and plain
  // char may be signed, so every byte goes through unsigned char first. This
  // matters for the bytes of non-ASCII (UTF-8) text.
  const auto fold = [](const char c) { return std::tolower(static_cast<unsigned char>(c)); };
  return std::equal(ending.rbegin(), ending.rend(), value.rbegin(),
                    [&fold](const char a, const char b) { return fold(a) == fold(b); });
}
14 changes: 6 additions & 8 deletions src/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#ifndef UTILITIES_H_
#define UTILITIES_H_

#include <unicode/locid.h>
#include <unicode/unistr.h>
#include <unicode/ustream.h>
#include <algorithm>
#include <cctype>
#include <codecvt>
Expand Down Expand Up @@ -208,14 +211,7 @@ void printFst(string loggerName, const fst::StdFst *fst, const fst::SymbolTable
template <typename StringFunction>
void splitString(const string &str, char delimiter, StringFunction f);

// Reports whether `value` ends with `ending`, ignoring case. The comparison is
// done byte by byte, folding each byte with std::tolower; an empty `ending`
// always matches and an `ending` longer than `value` never does.
static bool EndsWithCaseInsensitive(const std::string &value, const std::string &ending) {
  if (ending.size() > value.size()) {
    return false;
  }
  // Convert through unsigned char before calling std::tolower: passing a
  // negative char value (any non-ASCII byte when char is signed) is undefined.
  return std::equal(ending.rbegin(), ending.rend(), value.rbegin(), [](const char a, const char b) {
    return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
  });
}

bool EndsWithCaseInsensitive(const string &value, const string &ending);
bool iequals(const std::string &, const std::string &);

// string manip
Expand All @@ -242,4 +238,6 @@ string GetLabelNameFromClassLabel(string classLabel);

string GetClassLabel(string best_label);

string UnicodeLowercase(string token);

#endif // UTILITIES_H_
1 change: 1 addition & 0 deletions test/data/wer_utf.hyp.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ça va va bien aujourd'hui éte inutile Êtes
1 change: 1 addition & 0 deletions test/data/wer_utf.ref.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ça va bien aujourd'hui étÉ inutile êtes
8 changes: 8 additions & 0 deletions test/fstalign_Test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,14 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") {
REQUIRE_THAT(result, Contains("Wer Entity ID 3 WER: 1/3 = 0.3333"));
}

SECTION("wer_utf wer") {
const auto result =
exec(command("wer", approach, "wer_utf.ref.txt", "wer_utf.hyp.txt", sbs_output, "", TEST_SYNONYMS));

REQUIRE_THAT(result, Contains("WER: 2/7 = 0.2857"));
REQUIRE_THAT(result, Contains("WER: INS:1 DEL:0 SUB:1"));
}

// Additional WER tests
SECTION("entity precision recall") {
const auto testFile = std::string{TEST_DATA} + "twenty.hyp-a2.sbs";
Expand Down

0 comments on commit f14d9ff

Please sign in to comment.