Skip to content

Commit

Permalink
Bug: WER sidecar info not appearing in SBS (#55)
Browse files Browse the repository at this point in the history
* add test

* add test

* fix

* Add and use wer tag data structure

* fix test

* Remove debug log

* remove unigram and bigram info from sbs output

* fix log json missing unigram bigram info if output sbs not set

* version bump
  • Loading branch information
nishchalb authored Apr 18, 2024
1 parent fced1d9 commit 363deb8
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 178 deletions.
18 changes: 8 additions & 10 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,21 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma

// fuse multiple rows that have the same id/label into one entry only
for (auto &row : records) {
mNlpRows.push_back(row);
auto curr_tk = row.token;
auto curr_label = row.best_label;
auto curr_label_id = row.best_label_id;
auto punctuation = row.punctuation;
auto curr_row_tags = row.wer_tags;

// Resolve each wer tag against the WER sidecar (when one was loaded) to attach its entity type
vector<string> real_wer_tags;
for (auto &tag : curr_row_tags) {
auto real_tag = tag;
if (mWerSidecar != Json::nullValue) {
real_tag = "###" + real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
tag.entity_type = mWerSidecar[tag.tag_id]["entity_type"].asString();
}
real_wer_tags.push_back(real_tag);
}
row.wer_tags = real_wer_tags;
row.wer_tags = curr_row_tags;
std::string speaker = row.speakerId;
mNlpRows.push_back(row);

if (processLabels && curr_label != "") {
if (firstTk || curr_label != last_label) {
Expand Down Expand Up @@ -411,17 +408,18 @@ std::string NlpReader::GetBestLabel(std::string &labels) {
return labels;
}

std::vector<std::string> NlpReader::GetWerTags(std::string &wer_tags_str) {
std::vector<std::string> wer_tags;
std::vector<WerTagEntry> NlpReader::GetWerTags(std::string &wer_tags_str) {
std::vector<WerTagEntry> wer_tags;
if (wer_tags_str == "[]") {
return wer_tags;
}
// wer_tags_str looks like: ['89', '90', '100']
int current_pos = 2;
auto pos = wer_tags_str.find("'", current_pos);
while (pos != -1) {
std::string wer_tag = wer_tags_str.substr(current_pos, pos - current_pos);
wer_tags.push_back(wer_tag);
WerTagEntry entry;
entry.tag_id = wer_tags_str.substr(current_pos, pos - current_pos);
wer_tags.push_back(entry);
current_pos = wer_tags_str.find("'", pos + 1) + 1;
if (current_pos == 0) {
break;
Expand Down
9 changes: 7 additions & 2 deletions src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
using namespace std;
using namespace fst;

// A single WER tag attached to an NLP row, paired with the entity type
// resolved for it from the WER sidecar.
struct WerTagEntry {
// Tag identifier as parsed from the row's wer_tags column (e.g. the 89 in ['89', '90']).
string tag_id;
// "entity_type" value found in the WER sidecar JSON under tag_id; stays empty
// when no sidecar was loaded.
string entity_type;
};

struct RawNlpRecord {
string token;
string speakerId;
Expand All @@ -27,7 +32,7 @@ struct RawNlpRecord {
string labels;
string best_label;
string best_label_id;
vector<string> wer_tags;
vector<WerTagEntry> wer_tags;
string confidence;
};

Expand All @@ -37,7 +42,7 @@ class NlpReader {
virtual ~NlpReader();
vector<RawNlpRecord> read_from_disk(const std::string &filename);
string GetBestLabel(std::string &labels);
vector<string> GetWerTags(std::string &wer_tags_str);
vector<WerTagEntry> GetWerTags(std::string &wer_tags_str);
string GetLabelId(std::string &label);
};

Expand Down
3 changes: 2 additions & 1 deletion src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
<< "[";
/* for (auto wer_tag : nlpRow.wer_tags) { */
for (auto it = stitch.nlpRow.wer_tags.begin(); it != stitch.nlpRow.wer_tags.end(); ++it) {
output_nlp_file << "'" << *it << "'";
output_nlp_file << "'" << it->tag_id << "'";
if (std::next(it) != stitch.nlpRow.wer_tags.end()) {
output_nlp_file << ", ";
}
Expand Down Expand Up @@ -695,6 +695,7 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
}
}

JsonLogUnigramBigramStats(topAlignment);
if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
WriteSbs(topAlignment, stitches, output_sbs);
Expand Down
2 changes: 1 addition & 1 deletion src/version.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#define FSTALIGNER_VERSION_MAJOR 1
#define FSTALIGNER_VERSION_MINOR 12
#define FSTALIGNER_VERSION_MINOR 13
#define FSTALIGNER_VERSION_PATCH 0
41 changes: 10 additions & 31 deletions src/wer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,19 +350,16 @@ void RecordTagWer(const vector<Stitching>& stitches) {
for (const auto &stitch : stitches) {
if (!stitch.nlpRow.wer_tags.empty()) {
for (auto wer_tag : stitch.nlpRow.wer_tags) {
int tag_start = wer_tag.find_first_not_of('#');
int tag_end = wer_tag.find('_');
string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
wer_results.insert(std::pair<std::string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
wer_results.insert(std::pair<std::string, WerResult>(wer_tag.tag_id, {0, 0, 0, 0, 0}));
// rfind(prefix, 0) == 0 is a starts-with test: the comment may carry extra text after the del/ins/sub prefix
bool del = stitch.comment.rfind("del", 0) == 0;
bool ins = stitch.comment.rfind("ins", 0) == 0;
bool sub = stitch.comment.rfind("sub", 0) == 0;
wer_results[wer_tag_id].insertions += ins;
wer_results[wer_tag_id].deletions += del;
wer_results[wer_tag_id].substitutions += sub;
wer_results[wer_tag.tag_id].insertions += ins;
wer_results[wer_tag.tag_id].deletions += del;
wer_results[wer_tag.tag_id].substitutions += sub;
if (!ins) {
wer_results[wer_tag_id].numWordsInReference += 1;
wer_results[wer_tag.tag_id].numWordsInReference += 1;
}
}
}
Expand Down Expand Up @@ -555,7 +552,7 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
string tk_wer_tags = "";
auto wer_tags = p_stitch.nlpRow.wer_tags;
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + wer_tag + "|";
tk_wer_tags = tk_wer_tags + "###" + wer_tag.tag_id + "_" + wer_tag.entity_type + "###|";
}
string ref_tk = p_stitch.reftk;
string hyp_tk = p_stitch.hyptk;
Expand Down Expand Up @@ -606,6 +603,10 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
myfile << fmt::format("{0:>20}\t{1}", group.first, group.second) << endl;
}

myfile.close();
}

void JsonLogUnigramBigramStats(wer_alignment &topAlignment) {
for (const auto &a : topAlignment.unigram_stats) {
string word = a.first;
gram_error_counter u = a.second;
Expand All @@ -617,18 +618,6 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
jsonLogger::JsonLogger::getLogger().root["wer"]["unigrams"][word]["precision"] = u.precision;
jsonLogger::JsonLogger::getLogger().root["wer"]["unigrams"][word]["recall"] = u.recall;
}
// output error unigrams
myfile << string(60, '-') << endl << fmt::format("{0:>20}\t{1:10}\t{2:10}", "Unigram", "Prec.", "Recall") << endl;
for (const auto &a : topAlignment.unigram_stats) {
string word = a.first;
gram_error_counter u = a.second;
myfile << fmt::format("{0:>20}\t{1}/{2} ({3:.1f} %)\t{4}/{5} ({6:.1f} %)", word, u.correct,
(u.correct + u.ins + u.subst_fp), (float)u.precision, u.correct, (u.correct + u.del + u.subst_fn),
(float)u.recall)
<< endl;
}

myfile << string(60, '-') << endl << fmt::format("{0:>20}\t{1:20}\t{2:20}", "Bigram", "Precision", "Recall") << endl;

for (const auto &a : topAlignment.bigrams_stats) {
string word = a.first;
Expand All @@ -641,14 +630,4 @@ void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, st
jsonLogger::JsonLogger::getLogger().root["wer"]["bigrams"][word]["precision"] = u.precision;
jsonLogger::JsonLogger::getLogger().root["wer"]["bigrams"][word]["recall"] = u.recall;
}
for (const auto &a : topAlignment.bigrams_stats) {
string word = a.first;
gram_error_counter u = a.second;
myfile << fmt::format("{0:>20}\t{1}/{2} ({3:.1f} %)\t{4}/{5} ({6:.1f} %)", word, u.correct,
(u.correct + u.ins + u.subst_fp), (float)u.precision, u.correct, (u.correct + u.del + u.subst_fn),
(float)u.recall)
<< endl;
}

myfile.close();
}
1 change: 1 addition & 0 deletions src/wer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ typedef vector<pair<size_t, string>> ErrorGroups;

void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
void WriteSbs(wer_alignment &topAlignment, const vector<Stitching>& stitches, string sbs_filename);
void JsonLogUnigramBigramStats(wer_alignment &topAlignment);
2 changes: 1 addition & 1 deletion test/data/short.aligned.case.nlp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
about|1|0.0000|0.0000|||LC|[]|[]||||
Expand Down
2 changes: 1 addition & 1 deletion test/data/short.aligned.punc.nlp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
,|1|0.0000|0.0000|||||[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
Expand Down
2 changes: 1 addition & 1 deletion test/data/short.aligned.punc_case.nlp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ sure|1|0.0000|0.0000|.||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|['1', '2']||||
,|1|0.0000|0.0000|||||[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
Expand Down
52 changes: 52 additions & 0 deletions test/data/short.sbs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
ref_token hyp_token IsErr Class Wer_Tag_Entities
<crosstalk> <crosstalk>
Yeah Yeah
, ,
yeah <del> ERR
, <del> ERR
right right
. <del> ERR
Yeah <del> ERR
, <del> ERR
all <del> ERR
right <del> ERR
, I'll ERR
probably do ERR
just just
that that
. ? ERR
Are Are
there there
any any
visuals visuals
that that
come come
to to
mind mind
or or ___100002_SYN_1-1___
<ins> ? ERR
Yeah Yeah
, ,
sure sure
. .
When When
I I
hear hear
Foobar Foobar ###1_PROPER_NOUN###|###2_SPACY>ORG###|
, ,
I I
think think
about about
just just
that that
: :
<ins> Foobar ERR
foo , ERR
a a
------------------------------------------------------------
Line Group
5 yeah , <-> ***
8 . Yeah , all right , probably <-> I'll do
17 . <-> ?
27 *** <-> ?
43 foo <-> Foobar ,
66 changes: 33 additions & 33 deletions test/data/short_punc.ref.nlp
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
token|speaker|ts|endTs|punctuation|case|tags
<crosstalk>|2||||LC|[]
Yeah|1|||,|UC|[]
yeah|1|||,|LC|[]
right|1|||.|LC|[]
Yeah|1|||,|UC|[]
all|1||||LC|[]
right|1|||,|LC|[]
probably|1||||LC|[]
just|1||||LC|[]
that|1|||.|LC|[]
Are|3||||UC|[]
there|3||||LC|[]
any|3||||LC|[]
visuals|3||||LC|[]
that|3||||LC|[]
come|3||||LC|[]
to|3||||LC|[]
mind|3||||LC|[]
or-|3||||LC|[]
Yeah|1|||,|UC|[]
sure|1|||.|LC|[]
When|1||||UC|[]
I|1||||CA|[]
hear|1||||LC|[]
Foobar|1|||,|UC|[]
I|1||||CA|[]
think|1||||LC|[]
about|1||||LC|[]
just|1||||LC|[]
that|1|||:|LC|[]
foo|1||||LC|[]
a|1||||LC|[]
token|speaker|ts|endTs|punctuation|case|tags|wer_tags
<crosstalk>|2||||LC|[]|[]
Yeah|1|||,|UC|[]|[]
yeah|1|||,|LC|[]|[]
right|1|||.|LC|[]|[]
Yeah|1|||,|UC|[]|[]
all|1||||LC|[]|[]
right|1|||,|LC|[]|[]
probably|1||||LC|[]|[]
just|1||||LC|[]|[]
that|1|||.|LC|[]|[]
Are|3||||UC|[]|[]
there|3||||LC|[]|[]
any|3||||LC|[]|[]
visuals|3||||LC|[]|[]
that|3||||LC|[]|[]
come|3||||LC|[]|[]
to|3||||LC|[]|[]
mind|3||||LC|[]|[]
or-|3||||LC|[]|[]
Yeah|1|||,|UC|[]|[]
sure|1|||.|LC|[]|[]
When|1||||UC|[]|[]
I|1||||CA|[]|[]
hear|1||||LC|[]|[]
Foobar|1|||,|UC|[]|['1', '2']
I|1||||CA|[]|[]
think|1||||LC|[]|[]
about|1||||LC|[]|[]
just|1||||LC|[]|[]
that|1|||:|LC|[]|[]
foo|1||||LC|[]|[]
a|1||||LC|[]|[]
8 changes: 8 additions & 0 deletions test/data/short_punc.wer_tag.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"1": {
"entity_type": "PROPER_NOUN"
},
"2": {
"entity_type": "SPACY>ORG"
}
}
Loading

0 comments on commit 363deb8

Please sign in to comment.