Skip to content

Commit

Permalink
Put wer tag entity type in SBS output (#32)
Browse files Browse the repository at this point in the history
* put wer tag entity type in sbs

* add example in doc

* Input flag documentation

* fix link

* fix header, unused code

* test headers
  • Loading branch information
nishchalb authored Jun 21, 2022
1 parent 2456389 commit 3796629
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 89 deletions.
7 changes: 7 additions & 0 deletions docs/Advanced-Usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ Normalizations are a similar concept to synonyms. They allow a token or group of
}
```

### WER Sidecar

CLI flag: `--wer-sidecar`

Only usable for NLP format reference files. This passes a [WER sidecar](https://github.com/revdotcom/fstalign/blob/develop/docs/NLP-Format.md#wer-tag-sidecar) file to
add extra information to some outputs. Optional.

## Outputs

### Text Log
Expand Down
17 changes: 16 additions & 1 deletion docs/NLP-Format.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,19 @@ first|0||||LC|['6:DATE']|['6']
quarter|0||||LC|['6:DATE']|['6']
2020|0||||CA|['0:YEAR']|['0', '1', '6']
NexGEn|0||||MC|['7:ORG']|['7']
```
```

## WER tag sidecar

WER tag sidecar files contain accompanying information for tokens in an NLP file. Each
key is an ID corresponding to a token entry in the NLP file's `wer_tags` column, and the
object under that key holds information about the tagged token.

Example:
```json
{
  "0": {"entity_type": "YEAR"},
  "1": {"entity_type": "CARDINAL"},
  "6": {"entity_type": "SPACY>TIME"}
}
```
21 changes: 18 additions & 3 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,36 @@
/***********************************
NLP FstLoader class start
************************************/
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization)
: NlpFstLoader(records, normalization, true) {}
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar)
: NlpFstLoader(records, normalization, wer_sidecar, true) {}

NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels)
NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar, bool processLabels)
: FstLoader() {
mNlpRows = records;
mJsonNorm = normalization;
mWerSidecar = wer_sidecar;
std::string last_label;
bool firstTk = true;


// fuse multiple rows that have the same id/label into one entry only
for (auto &row : mNlpRows) {
auto curr_tk = row.token;
auto curr_label = row.best_label;
auto curr_label_id = row.best_label_id;
auto curr_row_tags = row.wer_tags;
// Update wer tags in records to real string labels
vector<string> real_wer_tags;
for (auto &tag: curr_row_tags) {
auto real_tag = tag;
if (mWerSidecar != Json::nullValue) {
real_tag = "###"+ real_tag + "_" + mWerSidecar[real_tag]["entity_type"].asString() + "###";
}
real_wer_tags.push_back(real_tag);
}
row.wer_tags = real_wer_tags;
std::string speaker = row.speakerId;

if (processLabels && curr_label != "") {
Expand Down
5 changes: 3 additions & 2 deletions src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class NlpReader {

class NlpFstLoader : public FstLoader {
public:
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, bool processLabels);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
virtual ~NlpFstLoader();
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const;
Expand All @@ -53,6 +53,7 @@ class NlpFstLoader : public FstLoader {
vector<RawNlpRecord> mNlpRows;
vector<std::string> mSpeakers;
Json::Value mJsonNorm;
Json::Value mWerSidecar;
virtual const std::string &getToken(int index) const { return mToken.at(index); }
};

Expand Down
46 changes: 23 additions & 23 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,34 +636,29 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
CalculatePrecisionRecall(topAlignment, alignerOptions.pr_threshold);

RecordWer(topAlignment);
if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
WriteSbs(topAlignment, output_sbs);
vector<shared_ptr<Stitching>> stitches;
CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
if (ctm_hyp_loader) {
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
} else if (nlp_hyp_loader) {
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
} else if (best_loader) {
vector<string> tokens;
tokens.reserve(best_loader->TokensSize());
for (int i = 0; i < best_loader->TokensSize(); i++) {
string token = best_loader->getToken(i);
tokens.push_back(token);
}
stitches = make_stitches(topAlignment, {}, {}, tokens);
} else {
stitches = make_stitches(topAlignment);
}

NlpFstLoader *nlp_ref_loader = dynamic_cast<NlpFstLoader *>(refLoader);
if (nlp_ref_loader) {
// We have an NLP reference, more metadata (e.g. speaker info) is available
vector<shared_ptr<Stitching>> stitches;
CtmFstLoader *ctm_hyp_loader = dynamic_cast<CtmFstLoader *>(hypLoader);
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(hypLoader);
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(hypLoader);
if (ctm_hyp_loader) {
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
} else if (nlp_hyp_loader) {
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
} else if (best_loader) {
vector<string> tokens;
tokens.reserve(best_loader->TokensSize());
for (int i = 0; i < best_loader->TokensSize(); i++) {
string token = best_loader->getToken(i);
tokens.push_back(token);
}
stitches = make_stitches(topAlignment, {}, {}, tokens);
} else {
stitches = make_stitches(topAlignment);
}

// Align stitches to the NLP, so stitches can access metadata
try {
align_stitches_to_nlp(nlp_ref_loader, &stitches);
Expand Down Expand Up @@ -693,6 +688,11 @@ void HandleWer(FstLoader *refLoader, FstLoader *hypLoader, SynonymEngine *engine
}
}

if (!output_sbs.empty()) {
logger->info("output_sbs = {}", output_sbs);
WriteSbs(topAlignment, stitches, output_sbs);
}

if (!output_nlp.empty() && !nlp_ref_loader) {
logger->warn("Attempted to output an Aligned NLP file without NLP reference, skipping output.");
}
Expand Down
36 changes: 34 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ int main(int argc, char **argv) {
setlocale(LC_ALL, "en_US.UTF-8");
string ref_filename;
string json_norm_filename;
string wer_sidecar_filename;
string hyp_filename;
string log_filename = "";
string output_nlp = "";
Expand Down Expand Up @@ -94,6 +95,8 @@ int main(int argc, char **argv) {
c->add_option("--composition-approach", composition_approach,
"Desired composition logic. Choices are 'standard' or 'adapted'");
}
get_wer->add_option("--wer-sidecar", wer_sidecar_filename,
"WER sidecar json file.");

get_wer->add_option("--speaker-switch-context", speaker_switch_context_size,
"Amount of context (in each direction) around "
Expand Down Expand Up @@ -166,6 +169,27 @@ int main(int argc, char **argv) {
Json::parseFromStream(builder, ss, &obj, &errs);
}

Json::Value wer_sidecar_obj;
if (!wer_sidecar_filename.empty()) {
console->info("reading wer sidecar info from {}", wer_sidecar_filename);
ifstream ifs(wer_sidecar_filename);

Json::CharReaderBuilder builder;
builder["collectComments"] = false;

JSONCPP_STRING errs;
Json::parseFromStream(builder, ifs, &wer_sidecar_obj, &errs);

console->info("The json we just read [{}] has {} elements from its root", wer_sidecar_filename, wer_sidecar_obj.size());
} else {
stringstream ss;
ss << "{}";

Json::CharReaderBuilder builder;
JSONCPP_STRING errs;
Json::parseFromStream(builder, ss, &wer_sidecar_obj, &errs);
}

Json::Value hyp_json_obj;
if (!hyp_json_norm_filename.empty()) {
console->info("reading hypothesis json norm info from {}", hyp_json_norm_filename);
Expand Down Expand Up @@ -194,7 +218,7 @@ int main(int argc, char **argv) {
NlpReader nlpReader = NlpReader();
console->info("reading reference nlp from {}", ref_filename);
auto vec = nlpReader.read_from_disk(ref_filename);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, true);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true);
ref = nlpFst;
} else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
console->info("reading reference ctm from {}", ref_filename);
Expand All @@ -212,11 +236,19 @@ int main(int argc, char **argv) {
// loading "hypothesis" inputs
if (EndsWithCaseInsensitive(hyp_filename, string(".nlp"))) {
console->info("reading hypothesis nlp from {}", hyp_filename);
// Make empty json for wer sidecar
Json::Value hyp_empty_json;
stringstream ss;
ss << "{}";

Json::CharReaderBuilder builder;
JSONCPP_STRING errs;
Json::parseFromStream(builder, ss, &hyp_empty_json, &errs);
NlpReader nlpReader = NlpReader();
auto vec = nlpReader.read_from_disk(hyp_filename);
// for now, nlp files passed as hypothesis won't have their labels handled as such
// this also mean that json normalization will be ignored
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, false);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false);
hyp = nlpFst;
} else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
console->info("reading hypothesis ctm from {}", hyp_filename);
Expand Down
32 changes: 20 additions & 12 deletions src/wer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,19 @@ void RecordTagWer(vector<shared_ptr<Stitching>> stitches) {
for (auto &stitch : stitches) {
if (!stitch->nlpRow.wer_tags.empty()) {
for (auto wer_tag : stitch->nlpRow.wer_tags) {
wer_results.insert(std::pair<std::string, WerResult>(wer_tag, {0, 0, 0, 0, 0}));
int tag_start = wer_tag.find_first_not_of('#');
int tag_end = wer_tag.find('_');
string wer_tag_id = wer_tag.substr(tag_start, tag_end - tag_start);
wer_results.insert(std::pair<std::string, WerResult>(wer_tag_id, {0, 0, 0, 0, 0}));
// Check with rfind since other comments can be there
bool del = stitch->comment.rfind("del", 0) == 0;
bool ins = stitch->comment.rfind("ins", 0) == 0;
bool sub = stitch->comment.rfind("sub", 0) == 0;
wer_results[wer_tag].insertions += ins;
wer_results[wer_tag].deletions += del;
wer_results[wer_tag].substitutions += sub;
wer_results[wer_tag_id].insertions += ins;
wer_results[wer_tag_id].deletions += del;
wer_results[wer_tag_id].substitutions += sub;
if (!ins) {
wer_results[wer_tag].numWordsInReference += 1;
wer_results[wer_tag_id].numWordsInReference += 1;
}
}
}
Expand Down Expand Up @@ -503,7 +506,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
hyp = "";
}

void WriteSbs(spWERA topAlignment, string sbs_filename) {
void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename) {
auto logger = logger::GetOrCreateLogger("wer");
logger->set_level(spdlog::level::info);

Expand All @@ -514,7 +517,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
triple *tk_pair = new triple();
string prev_tk_classLabel = "";
logger->info("Side-by-Side alignment info going into {}", sbs_filename);
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", "ref_token", "hyp_token", "IsErr", "Class") << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", "ref_token", "hyp_token", "IsErr", "Class", "Wer_Tag_Entities") << endl;

// keep track of error groupings
ErrorGroups groups_err;
Expand All @@ -525,10 +528,15 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
std::set<std::string> op_set = {"<ins>", "<del>", "<sub>"};

size_t offset = 2; // line number in output file where first triple starts
while (visitor.NextTriple(tk_pair)) {
string tk_classLabel = tk_pair->classLabel;
string ref_tk = tk_pair->ref;
string hyp_tk = tk_pair->hyp;
for (auto p_stitch: stitches) {
string tk_classLabel = p_stitch->classLabel;
string tk_wer_tags = "";
auto wer_tags = p_stitch->nlpRow.wer_tags;
for (auto wer_tag: wer_tags) {
tk_wer_tags = tk_wer_tags + wer_tag + "|";
}
string ref_tk = p_stitch->reftk;
string hyp_tk = p_stitch->hyptk;
string tag = "";

if (ref_tk == NOOP) {
Expand Down Expand Up @@ -560,7 +568,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
eff_class = tk_classLabel;
}

myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}", ref_tk, hyp_tk, tag, eff_class) << endl;
myfile << fmt::format("{0:>20}\t{1:20}\t{2}\t{3}\t{4}", ref_tk, hyp_tk, tag, eff_class, tk_wer_tags) << endl;
offset++;
}

Expand Down
2 changes: 1 addition & 1 deletion src/wer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ void CalculatePrecisionRecall(spWERA &topAlignment, int threshold);
typedef vector<pair<size_t, string>> ErrorGroups;

void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp);
void WriteSbs(spWERA topAlignment, string sbs_filename);
void WriteSbs(spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename);
50 changes: 25 additions & 25 deletions test/data/syn_1.hyp.sbs
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
ref_token hyp_token IsErr Class
we <del> ERR
will we'll ERR
have have
a a
nice nice
evening evening
<ins> um ERR
no no
matter matter
what what
will will
happen happen
<ins> it ERR
um is ERR
it's uh ERR
a a
good good
opportunity opportunity
to to
do <del> ERR
this this
you'll you'll
<ins> uh ERR
see see
ref_token hyp_token IsErr Class Wer_Tag_Entities
we <del> ERR
will we'll ERR
have have
a a
nice nice
evening evening
<ins> um ERR
no no
matter matter
what what
will will
happen happen
<ins> it ERR
um is ERR
it's uh ERR
a a
good good
opportunity opportunity
to to
do <del> ERR
this this
you'll you'll
<ins> uh ERR
see see
------------------------------------------------------------
Line Group
2 we will <-> we'll
Expand Down
20 changes: 10 additions & 10 deletions test/data/twenty.hyp-a2.sbs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
ref_token hyp_token IsErr Class
20 <del> ERR ___1_CARDINAL___
in in
twenty twenty ___2_YEAR___
twenty thirty ERR ___2_YEAR___
is is
one one ___3_CARDINAL___
twenty twenty ___3_CARDINAL___
<ins> two ERR ___3_CARDINAL___
three three ___3_CARDINAL___
ref_token hyp_token IsErr Class Wer_Tag_Entities
20 <del> ERR ___1_CARDINAL___
in in
twenty twenty ___2_YEAR___
twenty thirty ERR ___2_YEAR___
is is
one one ___3_CARDINAL___
twenty twenty ___3_CARDINAL___
<ins> two ERR ___3_CARDINAL___
three three ___3_CARDINAL___
------------------------------------------------------------
Line Group
2 20 <-> ***
Expand Down
Loading

0 comments on commit 3796629

Please sign in to comment.