Skip to content

Commit

Permalink
Smallish fixes to clstmctc.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbdev committed Sep 24, 2015
1 parent fa76c70 commit d49ebf4
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions clstmctc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ int main_ocr(int argc, char **argv) {
bool randomize = getienv("randomize", 1);
string lrnorm = getoneof("lrnorm", "batch");
string dewarp = getoneof("dewarp", "none");
string net_type = getoneof("lstm", "BIDILSTM");
string net_type = getoneof("lstm", "bidi");
string lstm_type = getoneof("lstm_type", "LSTM");
string output_type = getoneof("output_type", "SoftmaxLayer");
string target_name = getoneof("target_name", "ctc");
Expand Down Expand Up @@ -310,12 +310,13 @@ int main_ocr(int argc, char **argv) {
save_net(fname, net);
if (after_save != "") system(after_save.c_str());
}
if (trial > 0 && test_every > 0 && trial % test_every == 0 &&
testset != "") {
if (testset != "" &&
((trial > 0 && test_every > 0 && trial % test_every == 0) ||
(trial == ntrain-1))) {
double erate = error_rate(net, testset);
net->attributes["trial"] = to_string(trial);
net->attributes["last_err"] = to_string(best_erate);
print("TESTERR", now() - start_time, save_name, trial, erate, "lrate",
print("TSETERR", erate, "@", trial, ":", now() - start_time, "lrate",
lrate, "hidden", nhidden, nhidden2, "pseudo_batch", pseudo_batch,
"momentum", momentum);
if (save_every == 0 && erate < best_erate) {
Expand Down Expand Up @@ -353,9 +354,8 @@ int main_ocr(int argc, char **argv) {
break;
}
assert(saligned.size() == net->outputs.size());
net->d_outputs.resize(net->outputs.size());
for (int t = 0; t < saligned.size(); t++)
net->d_outputs[t] = saligned[t] - net->outputs[t];
net->outputs[t].d = saligned[t] - net->outputs[t];
net->backward();
if (trial % pseudo_batch == 0) net->update();
Classes output_classes, aligned_classes;
Expand All @@ -372,7 +372,6 @@ int main_ocr(int argc, char **argv) {
}

if (display_every > 0 && trial % display_every == 0) {
net->d_outputs.resize(saligned.size());
py->eval("clf()");
py->subplot(4, 1, 1);
py->evalf("title('%s')", gt.c_str());
Expand Down

0 comments on commit d49ebf4

Please sign in to comment.