Skip to content

Commit

Permalink
Nerd-1422: Add flag for reading punctuation from nlp as separate toke…
Browse files Browse the repository at this point in the history
…ns (#10)

* Nerd-1422: Add flag for reading punctuation from nlp as separate tokens

* test

* version file

---------

Co-authored-by: Nishchal Bhandari <nishchal2050@gmail.com>
  • Loading branch information
ajhinsvark and nishchalb authored Apr 20, 2023
1 parent 99afe1b commit 4c579ad
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 8 deletions.
15 changes: 12 additions & 3 deletions src/Nlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
: NlpFstLoader(records, normalization, wer_sidecar, true) {}

NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization,
Json::Value wer_sidecar, bool processLabels)
Json::Value wer_sidecar, bool processLabels, bool use_punctuation)
: FstLoader() {
mNlpRows = records;
mJsonNorm = normalization;
mWerSidecar = wer_sidecar;
std::string last_label;
bool firstTk = true;


// fuse multiple rows that have the same id/label into one entry only
for (auto &row : mNlpRows) {
for (auto &row : records) {
mNlpRows.push_back(row);
auto curr_tk = row.token;
auto curr_label = row.best_label;
auto curr_label_id = row.best_label_id;
auto punctuation = row.punctuation;
auto curr_row_tags = row.wer_tags;
// Update wer tags in records to real string labels
vector<string> real_wer_tags;
Expand Down Expand Up @@ -83,6 +84,14 @@ NlpFstLoader::NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value norma
std::string lower_cased = UnicodeLowercase(curr_tk);
mToken.push_back(lower_cased);
mSpeakers.push_back(speaker);
// When punctuation handling is enabled and this row carries punctuation,
// emit the punctuation as an extra token right after the word, attributed
// to the same speaker.
if (use_punctuation && punctuation != "") {
mToken.push_back(punctuation);
mSpeakers.push_back(speaker);
// Append a copy of the row whose token is the punctuation mark itself and
// whose punctuation field is cleared.
// NOTE(review): presumably this keeps mNlpRows aligned one-to-one with
// mToken for later row lookups -- confirm against the consumers of mNlpRows.
RawNlpRecord nlp_row = row;
nlp_row.token = nlp_row.punctuation;
nlp_row.punctuation = "";
mNlpRows.push_back(nlp_row);
}
}

firstTk = false;
Expand Down
2 changes: 1 addition & 1 deletion src/Nlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class NlpReader {

class NlpFstLoader : public FstLoader {
public:
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar, bool processLabels, bool use_punctuation = false);
NlpFstLoader(std::vector<RawNlpRecord> &records, Json::Value normalization, Json::Value wer_sidecar);
virtual ~NlpFstLoader();
virtual void addToSymbolTable(fst::SymbolTable &symbol) const;
Expand Down
4 changes: 4 additions & 0 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ void align_stitches_to_nlp(NlpFstLoader *refLoader, vector<shared_ptr<Stitching>
continue;
}

// Stop before indexing nlpRows below once nlpRowIndex has reached nlpMaxRow;
// log a warning and leave the loop instead.
// NOTE(review): assumes nlpMaxRow is the number of entries in nlpRows -- confirm.
if (nlpRowIndex >= nlpMaxRow) {
logger->warn("Ran out of nlp rows. {} rows, {} stitches", nlpMaxRow, numStitches);
break;
}
auto nlpPart = nlpRows[nlpRowIndex];
string nlp_classLabel = GetClassLabel(nlpPart.best_label);

Expand Down
6 changes: 4 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ int main(int argc, char **argv) {
int numBests = 100;
int levenstein_maximum_error_streak = 100;
bool record_case_stats = false;
bool use_punctuation = false;
bool disable_approximate_alignment = false;

bool disable_cutoffs = false;
Expand Down Expand Up @@ -120,6 +121,7 @@ int main(int argc, char **argv) {
// Register --record-case-stats. The help text is built from adjacent string
// literals, which the compiler joins with no separator, so the first literal
// must end with a space to keep "hypothesis" and "casing" as separate words.
get_wer->add_flag("--record-case-stats", record_case_stats,
                  "Record precision/recall for how well the hypothesis "
                  "casing matches the reference.");
get_wer->add_flag("--use-punctuation", use_punctuation, "Treat punctuation from nlp rows as separate tokens");

// CLI11_PARSE(app, argc, argv);
try {
Expand Down Expand Up @@ -218,7 +220,7 @@ int main(int argc, char **argv) {
NlpReader nlpReader = NlpReader();
console->info("reading reference nlp from {}", ref_filename);
auto vec = nlpReader.read_from_disk(ref_filename);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, obj, wer_sidecar_obj, true, use_punctuation);
ref = nlpFst;
} else if (EndsWithCaseInsensitive(ref_filename, string(".ctm"))) {
console->info("reading reference ctm from {}", ref_filename);
Expand Down Expand Up @@ -248,7 +250,7 @@ int main(int argc, char **argv) {
auto vec = nlpReader.read_from_disk(hyp_filename);
// for now, nlp files passed as hypothesis won't have their labels handled as such
// this also mean that json normalization will be ignored
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false);
NlpFstLoader *nlpFst = new NlpFstLoader(vec, hyp_json_obj, hyp_empty_json, false, use_punctuation);
hyp = nlpFst;
} else if (EndsWithCaseInsensitive(hyp_filename, string(".ctm"))) {
console->info("reading hypothesis ctm from {}", hyp_filename);
Expand Down
4 changes: 2 additions & 2 deletions src/version.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once

#define FSTALIGNER_VERSION_MAJOR 1
#define FSTALIGNER_VERSION_MINOR 6
#define FSTALIGNER_VERSION_PATCH 1
#define FSTALIGNER_VERSION_MINOR 9
#define FSTALIGNER_VERSION_PATCH 0
43 changes: 43 additions & 0 deletions test/data/short.aligned.punc.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]|||
Yeah|1|0.0000|0.0000|,||UC|[]|[]|||
,|1|0.0000|0.0000|||UC|[]|[]|||
yeah|1|||,||LC|[]|[]|||del
,|1|||||LC|[]|[]|||del
right|1|0.0000|0.0000|.||LC|[]|[]|||
.|1|||||LC|[]|[]|||del
Yeah|1|||,||UC|[]|[]|||del
,|1|||||UC|[]|[]|||del
all|1|||||LC|[]|[]|||del
right|1|||,||LC|[]|[]|||del
,|1|0.0000|0.0000|||LC|[]|[]|||sub(i'll)
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)
just|1|0.0000|0.0000|||LC|[]|[]|||
that|1|0.0000|0.0000|.||LC|[]|[]|||
.|1|0.0000|0.0000|||LC|[]|[]|||sub(?)
Are|3|0.0000|0.0000|||UC|[]|[]|||
there|3|0.0000|0.0000|||LC|[]|[]|||
any|3|0.0000|0.0000|||LC|[]|[]|||
visuals|3|0.0000|0.0000|||LC|[]|[]|||
that|3|0.0000|0.0000|||LC|[]|[]|||
come|3|0.0000|0.0000|||LC|[]|[]|||
to|3|0.0000|0.0000|||LC|[]|[]|||
mind|3|0.0000|0.0000|||LC|[]|[]|||
or|3|0.0000|0.0000|||LC|[]|[]|||
Yeah|1|0.0000|0.0000|,||UC|[]|[]|||
,|1|0.0000|0.0000|||UC|[]|[]|||
sure|1|0.0000|0.0000|.||LC|[]|[]|||
.|1|0.0000|0.0000|||LC|[]|[]|||
When|1|0.0000|0.0000|||UC|[]|[]|||
I|1|0.0000|0.0000|||CA|[]|[]|||
hear|1|0.0000|0.0000|||LC|[]|[]|||
Foobar|1|0.0000|0.0000|,||UC|[]|[]|||
,|1|0.0000|0.0000|||UC|[]|[]|||
I|1|0.0000|0.0000|||CA|[]|[]|||
think|1|0.0000|0.0000|||LC|[]|[]|||
about|1|0.0000|0.0000|||LC|[]|[]|||
just|1|0.0000|0.0000|||LC|[]|[]|||
that|1|0.0000|0.0000|:||LC|[]|[]|||
:|1|0.0000|0.0000|||LC|[]|[]|||
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(,)
a|1|0.0000|0.0000|||LC|[]|[]|||
30 changes: 30 additions & 0 deletions test/data/short_punc.hyp.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
token|speaker|ts|endTs|punctuation|case|tags
<crosstalk>|2||||LC|[]
Yeah|1|||,|UC|[]
right|1||||LC|[]
I'll|1||||UC|[]
do|1||||LC|[]
just|1||||LC|[]
that|1|||?|LC|[]
Are|3||||UC|[]
there|3||||LC|[]
any|3||||LC|[]
visuals|3||||LC|[]
that|3||||LC|[]
come|3||||LC|[]
to|3||||LC|[]
mind|3||||LC|[]
or|3|||?|LC|[]
Yeah|1|||,|UC|[]
sure|1|||.|LC|[]
When|1||||UC|[]
I|1||||CA|[]
hear|1||||LC|[]
Foobar|1|||,|UC|[]
I|1||||CA|[]
think|1||||LC|[]
about|1||||LC|[]
just|1||||LC|[]
that|1|||:|LC|[]
Foobar|1|||,|UC|[]
a|1||||LC|[]
33 changes: 33 additions & 0 deletions test/data/short_punc.ref.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
token|speaker|ts|endTs|punctuation|case|tags
<crosstalk>|2||||LC|[]
Yeah|1|||,|UC|[]
yeah|1|||,|LC|[]
right|1|||.|LC|[]
Yeah|1|||,|UC|[]
all|1||||LC|[]
right|1|||,|LC|[]
probably|1||||LC|[]
just|1||||LC|[]
that|1|||.|LC|[]
Are|3||||UC|[]
there|3||||LC|[]
any|3||||LC|[]
visuals|3||||LC|[]
that|3||||LC|[]
come|3||||LC|[]
to|3||||LC|[]
mind|3||||LC|[]
or-|3||||LC|[]
Yeah|1|||,|UC|[]
sure|1|||.|LC|[]
When|1||||UC|[]
I|1||||CA|[]
hear|1||||LC|[]
Foobar|1|||,|UC|[]
I|1||||CA|[]
think|1||||LC|[]
about|1||||LC|[]
just|1||||LC|[]
that|1|||:|LC|[]
foo|1||||LC|[]
a|1||||LC|[]
10 changes: 10 additions & 0 deletions test/fstalign_Test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,16 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") {
REQUIRE_THAT(result, Contains("WER: INS:0 DEL:3 SUB:3"));
}

SECTION("wer with punctuation(nlp output)") {
  // Run the `wer` command on the punctuation fixtures, with the
  // --use-punctuation flag appended to the generated command line.
  const auto run_output = exec(
      command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS) +
      " --use-punctuation");

  // The file at nlp_output (passed to the command above) is compared with
  // the checked-in expected alignment.
  const auto expected_nlp_path = std::string{TEST_DATA} + "short.aligned.punc.nlp";
  REQUIRE(compareFiles(nlp_output.c_str(), expected_nlp_path.c_str()));

  // The captured output must report the expected WER total and breakdown.
  REQUIRE_THAT(run_output, Contains("WER: 13/42 = 0.3095"));
  REQUIRE_THAT(run_output, Contains("WER: INS:2 DEL:7 SUB:4"));
}

// alignment tests

SECTION("align_1") {
Expand Down

0 comments on commit 4c579ad

Please sign in to comment.