Skip to content

Commit

Permalink
Added NDEBUG, functions for merging models.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbdev committed Oct 26, 2015
1 parent fee039e commit 484108c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
2 changes: 1 addition & 1 deletion SConstruct
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ elif option("debug", 0)>0:
env.Append(CCFLAGS="-g".split())
env.Append(LINKFLAGS="-g".split())
else:
env.Append(CXXFLAGS="-g -O3 -finline".split())
env.Append(CXXFLAGS="-g -O3 -DNDEBUG -finline".split())
env.Append(CCFLAGS="-g".split())

# Extra layers (old layers or testing)
Expand Down
29 changes: 20 additions & 9 deletions clstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,25 @@ static void get_allparams(vector<vector<Params*>> &allparams, vector<Network> &n
}
}

void share_deltas(vector<Network> networks) {
void distribute_weights(vector<Network> &networks, int from) {
vector<vector<Params*>> allparams;
get_allparams(allparams, networks);
int n = allparams.size();
int m = allparams[0].size();
for(int i=1; n; i++) {
for(int i=0; i<n; i++) {
if (i==from) continue;
for(int j=0; j<m; j++) {
allparams[i][j]->V() = allparams[from][j]->V();
}
}
}

void share_deltas(vector<Network> &networks) {
vector<vector<Params*>> allparams;
get_allparams(allparams, networks);
int n = allparams.size();
int m = allparams[0].size();
for(int i=1; i<n; i++) {
for(int j=0; j<m; j++) {
allparams[0][j]->D() += allparams[i][j]->D();
}
Expand All @@ -555,7 +568,7 @@ void share_deltas(vector<Network> networks) {
}
}

void average_weights(vector<Network> networks) {
void average_weights(vector<Network> &networks) {
vector<vector<Params*>> allparams;
get_allparams(allparams, networks);
int n = allparams.size();
Expand All @@ -564,13 +577,11 @@ void average_weights(vector<Network> networks) {
for(int j=0; j<m; j++) {
allparams[0][j]->V() += allparams[i][j]->V();
}
for(int j=0; j<m; j++) {
allparams[0][j]->V() = allparams[0][j]->V() * Float(1.0/n);
}
for(int j=0; j<m; j++) {
allparams[i][j]->V() = allparams[0][j]->V();
}
}
for(int j=0; j<m; j++) {
allparams[0][j]->V() = allparams[0][j]->V() * Float(1.0/n);
}
distribute_weights(networks);
}

} // namespace ocropus
5 changes: 3 additions & 2 deletions clstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ void trivial_decode(Classes &cs, Sequence &outputs, int batch,
// single sequence training functions
void mktargets(Tensor<float,2> &seq, Tensor<int,1> &targets, int ndim);

void share_deltas(vector<Network> networks);
void average_weights(vector<Network> networks);
void share_deltas(vector<Network> &networks);
void average_weights(vector<Network> &networks);
void distribute_weights(vector<Network> &networks, int from=0);
}

namespace {
Expand Down
13 changes: 10 additions & 3 deletions clstmhl.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct CLSTMText {
};

struct CLSTMOCR {
unique_ptr<INormalizer> normalizer;
shared_ptr<INormalizer> normalizer;
Network net;
int target_height = 48;
int nclasses = -1;
Expand All @@ -132,7 +132,7 @@ struct CLSTMOCR {
normalizer.reset(make_CenterNormalizer());
normalizer->target_height = target_height;
}
std::wstring train(Tensor<float,2> &raw, const std::wstring &target) {
std::wstring fwdbwd(Tensor<float,2> &raw, const std::wstring &target) {
normalizer->measure(raw);
normalizer->normalize(image, raw);
set_inputs(net, image);
Expand All @@ -144,11 +144,18 @@ struct CLSTMOCR {
for (int t = 0; t < aligned.size(); t++)
net->outputs[t].D() = aligned[t].V() - net->outputs[t].V();
net->backward();
sgd_update(net);
Classes outputs;
trivial_decode(outputs, net->outputs);
return net->codec.decode(outputs);
}
void update() {
sgd_update(net);
}
std::wstring train(Tensor<float,2> &raw, const std::wstring &target) {
std::wstring result = fwdbwd(raw, target);
update();
return result;
}
std::string aligned_utf8() {
Classes outputs;
trivial_decode(outputs, aligned);
Expand Down
2 changes: 1 addition & 1 deletion run-tests
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export PS4='
>>>>>>> '
trap "echo TEST FAILED" EXIT
set -x
export seed=0.1423
export seed=0.7733
scons -s -c; rm -f *.o *.a
scons -j 4 clstmocrtrain clstmfiltertrain clstmfilter clstmocr test-lstm
./test-lstm
Expand Down

0 comments on commit 484108c

Please sign in to comment.