Skip to content

Commit

Permalink
FSTALIGN-61: Fewer explicit heap allocations (#45)
Browse files Browse the repository at this point in the history
* wip

* debugging

* Revert "debugging"

This reverts commit 786016b.

* fixing bugs - tests now passing

* More refactoring

* Removing more unnecessary shared_ptrs
  • Loading branch information
dchen579 authored May 23, 2023
1 parent 1e6a47a commit bf4979a
Show file tree
Hide file tree
Showing 17 changed files with 469 additions and 471 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")

enable_testing()

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if(DEFINED ENV{OPENFST_ROOT})
Expand Down
34 changes: 16 additions & 18 deletions src/AlignmentTraversor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ AlignmentTraversor.cpp
*/
#include "AlignmentTraversor.h"

AlignmentTraversor::AlignmentTraversor(spWERA topLevel) {
root = topLevel;
AlignmentTraversor::AlignmentTraversor(wer_alignment &topLevel) : root(topLevel) {
currentPosInRoot = -1;
currentSubclass = nullptr;
}
Expand All @@ -18,43 +17,42 @@ void AlignmentTraversor::Restart() {
currentPosInSubclass = -1;
}

bool AlignmentTraversor::NextTriple(triple *triple) {
bool AlignmentTraversor::NextTriple(triple &triple) {
if (currentSubclass == nullptr) {
// we're not in a subclass, we're consuming the root alignment content,
// let's move to the next word
currentPosInRoot++;
if (currentPosInRoot >= root->tokens.size()) {
triple = nullptr;
if (currentPosInRoot >= root.tokens.size()) {
return false;
}

auto tk = root->tokens[currentPosInRoot];
auto tk = root.tokens[currentPosInRoot];
if (isEntityLabel(tk.first)) {
// handle class
currentPosInSubclass = -1;
// find the subclass alignment within the root whose class label matches this token
for (auto &a : root->label_alignments) {
if (a->classLabel == tk.first) {
currentSubclass = a;
for (auto &a : root.label_alignments) {
if (a.classLabel == tk.first) {
currentSubclass = &a;
break;
}
}
// currentSubclass = nullptr; // fixme
return NextTriple(triple);
}

triple->classLabel = TK_GLOBAL_CLASS;
triple->ref = tk.first;
triple->hyp = tk.second;
triple.classLabel = TK_GLOBAL_CLASS;
triple.ref = tk.first;
triple.hyp = tk.second;

return true;
} else {
currentPosInSubclass++;
if (currentPosInSubclass == 0 && currentSubclass->tokens.size() == 0 &&
currentSubclass->classLabel.find("FALLBACK") != std::string::npos) {
triple->classLabel = currentSubclass->classLabel;
triple->ref = NOOP;
triple->hyp = NOOP;
triple.classLabel = currentSubclass->classLabel;
triple.ref = NOOP;
triple.hyp = NOOP;
return true;
}
if (currentPosInSubclass >= currentSubclass->tokens.size()) {
Expand All @@ -65,9 +63,9 @@ bool AlignmentTraversor::NextTriple(triple *triple) {
}

auto tk = currentSubclass->tokens[currentPosInSubclass];
triple->classLabel = currentSubclass->classLabel;
triple->ref = tk.first;
triple->hyp = tk.second;
triple.classLabel = currentSubclass->classLabel;
triple.ref = tk.first;
triple.hyp = tk.second;
return true;
}
}
8 changes: 4 additions & 4 deletions src/AlignmentTraversor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ struct triple {

class AlignmentTraversor {
public:
AlignmentTraversor(spWERA topLevel);
bool NextTriple(triple *triple);
AlignmentTraversor(wer_alignment &topLevel);
bool NextTriple(triple &triple);
void Restart();

private:
spWERA root;
wer_alignment &root;
int currentPosInRoot = -1;
int currentPosInSubclass;
spWERA currentSubclass;
wer_alignment *currentSubclass;
};

#endif // __ATRAVERSOR_H__
13 changes: 13 additions & 0 deletions src/FstLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ class FstLoader {
static void AddSymbolIfNeeded(fst::SymbolTable &symbol, std::string str_value);
virtual fst::StdVectorFst convertToFst(const fst::SymbolTable &symbol, std::vector<int> map) const = 0;
virtual std::vector<int> convertToIntVector(fst::SymbolTable &symbol) const = 0;

static std::unique_ptr<FstLoader> MakeReferenceLoader(const std::string& ref_filename,
const std::string& wer_sidecar_filename,
const std::string& json_norm_filename,
bool use_punctuation,
bool symbols_file_included);

static std::unique_ptr<FstLoader> MakeHypothesisLoader(const std::string& hyp_filename,
const std::string& hyp_json_norm_filename,
bool use_punctuation,
bool symbols_file_included);


};

#endif /* __FSTLOADER_H_ */
25 changes: 12 additions & 13 deletions src/PathHeap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,29 @@ using namespace fst;

PathHeap::PathHeap() {
// creating the set
heap = new set<shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>();
}

void PathHeap::insert(shared_ptr<ShortlistEntry> entry) {
// just add it to the set, leaving to the comparator to do its job
heap->insert(entry);
heap.insert(entry);
}

shared_ptr<ShortlistEntry> PathHeap::removeFirst() {
// we want to take the 1st element and remove it
auto logbookIter = heap->begin();
auto logbookIter = heap.begin();
auto currentState_ptr = *logbookIter;
heap->erase(logbookIter);
heap.erase(logbookIter);
return currentState_ptr;
}

int PathHeap::size() { return heap->size(); }
int PathHeap::size() { return heap.size(); }

shared_ptr<ShortlistEntry> PathHeap::GetBestWerCandidate() {
set<shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>::iterator iter = heap->begin();
set<shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>::iterator iter = heap.begin();

shared_ptr<ShortlistEntry> best = nullptr;
float bestWer = std::numeric_limits<float>::quiet_NaN();
while (iter != heap->end()) {
while (iter != heap.end()) {
auto entry = *iter;
float local_wer = (float)entry->numErrors / (float)entry->numWords;

Expand All @@ -57,9 +56,9 @@ shared_ptr<ShortlistEntry> PathHeap::GetBestWerCandidate() {
}

int PathHeap::prune(int targetSz) {
set<shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>::iterator iter = heap->begin();
set<shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>::iterator iter = heap.begin();
float wer0, wer_last;
int sz = heap->size();
int sz = heap.size();
for (int i = 0; i < targetSz && i < sz; i++) {
float local_wer = (float)(*iter)->numErrors / ((float)(*iter)->numWords);
if (i == 0) {
Expand All @@ -77,15 +76,15 @@ int PathHeap::prune(int targetSz) {
// logger->set_level(spdlog::level::debug);
logger->debug("==== pruning starting =====");
logger->debug("pruning to {} items -> top wer was {} and last wer was {}. We have {} items in the heap.", targetSz,
wer0, wer_last, heap->size());
wer0, wer_last, heap.size());

/* TODO: make sure we don't prune paths that have the same length/error-count
as the last one kept at 'targetSz'
*/

int numErrorsWithoutInsertions = (*last_wer_index)->numErrors - (*last_wer_index)->numInsert;
int pruned = 0;
while (iter != heap->end()) {
while (iter != heap.end()) {
auto p = *iter;
float local_wer = (float)(*iter)->numErrors / ((float)(*iter)->numWords);
logger->debug(
Expand All @@ -104,13 +103,13 @@ bool pruneMe = (*last_wer_index)->numErrors + 20 < (*iter)->numErrors; --> seem
bool pruneMe = (*last_wer_index)->numErrors + 20 < (*iter)->numErrors;
logger->debug("{} + 20 < {} = {}", numErrorsWithoutInsertions, localCoreErr, pruneMe);
if (pruneMe) {
heap->erase(iter++);
heap.erase(iter++);
pruned++;
} else {
iter++;
}
}
logger->debug("after pruning we have {} items in the heap", heap->size());
logger->debug("after pruning we have {} items in the heap", heap.size());
logger->debug("-----");

return pruned;
Expand Down
14 changes: 7 additions & 7 deletions src/PathHeap.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ typedef struct MyArc* MyArcPtr;
typedef shared_ptr<ShortlistEntry> spSLE;

struct MyArc {
int ilabel;
int olabel;
float weight;
int nextstate;
int ilabel = 0;
int olabel = 0;
float weight = 0.0;
int nextstate = 0;
};

struct ShortlistEntry {
Expand All @@ -36,7 +36,7 @@ struct ShortlistEntry {
int numInsert = 0;
double costToGoThere = 0;
float costSoFar = 0;
shared_ptr<MyArc> local_arc;
MyArc local_arc;
shared_ptr<ShortlistEntry> linkToHere = nullptr;
};

Expand Down Expand Up @@ -70,6 +70,6 @@ class PathHeap {
bool pruningIncludeInsInThreshold = true;

private:
set<std::shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr>* heap;
set<std::shared_ptr<ShortlistEntry>, shortlistComparatorSharedPtr> heap;
};
#endif // __PATH_HEAP_H__
#endif // __PATH_HEAP_H__
8 changes: 2 additions & 6 deletions src/StandardComposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ StandardCompositionFst::StandardCompositionFst(const fst::StdFst &fstA, const fs
}

logger_->info("performing lazy composition");
fstC_ = new fst::StdComposeFst(halfCompose1, halfCompose2);
fstC_ = std::make_unique<fst::StdComposeFst>(halfCompose1, halfCompose2);

// initialize internal stores. if we don't initialize the state iterator
// (even if we don't really use it) then any call to ArcIterator(fst,
Expand All @@ -130,11 +130,7 @@ bool StandardCompositionFst::TryGetArcsAtState(StateId fromStateId, vector<fst::
return true;
}

StandardCompositionFst::~StandardCompositionFst() {
if (fstC_ != NULL) {
delete fstC_;
}
}
StandardCompositionFst::~StandardCompositionFst() {}

void StandardCompositionFst::DebugComposedGraph(string debug_filename) {
StdVectorFst composedFst(*fstC_);
Expand Down
4 changes: 2 additions & 2 deletions src/StandardComposition.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class StandardCompositionFst : public IComposition {
protected:
// Lazily composed fst, created during initialization
fst::StdComposeFst *fstC_;
std::unique_ptr<fst::StdComposeFst> fstC_;

public:
StandardCompositionFst(const fst::StdFst &fstA, const fst::StdFst &fstB);
Expand All @@ -42,4 +42,4 @@ class StandardCompositionFst : public IComposition {
void DebugComposedGraph(string debug_filename);
};

#endif /* __STANDARDCOMPOSITION_H__ */
#endif /* __STANDARDCOMPOSITION_H__ */
Loading

0 comments on commit bf4979a

Please sign in to comment.