Skip to content

Commit

Permalink
Put deltas into parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbdev committed Sep 23, 2015
1 parent ebab9e2 commit fa76c70
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 36 deletions.
67 changes: 32 additions & 35 deletions clstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,7 @@ inline Float sigmoid(Float x) {

template <class NONLIN>
struct Full : NetworkBase {
Mat W, d_W;
Vec w, d_w;
Params W, w;
int nseq = 0;
int nsteps = 0;
string mykind = string("Full_") + NONLIN::kind;
Expand All @@ -346,27 +345,28 @@ struct Full : NetworkBase {
int no = irequire("noutput");
int ni = irequire("ninput");
randinit(W, no, ni, 0.01);
randinit(w, no, 0.01);
zeroinit(d_W, no, ni);
zeroinit(d_w, no);
randinit(w, no, 1, 0.01);
zeroinit(W.d, no, ni);
zeroinit(w.d, no, 1);
}
void forward() {
outputs.resize(inputs.size());
for (int t = 0; t < inputs.size(); t++) {
outputs[t] = MATMUL(W, inputs[t]);
ADDCOLS(outputs[t], w);
Vec v = COL(w, 0);
ADDCOLS(outputs[t], v);
NONLIN::f(outputs[t]);
}
}
void backward() {
for (int t = outputs.size() - 1; t >= 0; t--) {
NONLIN::df(outputs[t].d, outputs[t]);
inputs[t].d = MATMUL_TR(W, outputs[t].d);
inputs[t].d = MATMUL_TR(W.d, outputs[t].d);
}
int bs = COLS(inputs[0]);
for (int t = 0; t < outputs.size(); t++) {
d_W += MATMUL_RT(outputs[t].d, inputs[t]);
for (int b = 0; b < bs; b++) d_w += COL(outputs[t].d, b);
W.d += MATMUL_RT(outputs[t].d, inputs[t]);
for (int b = 0; b < bs; b++) w.d += COL(outputs[t].d, b);
}
nseq += 1;
nsteps += outputs.size();
Expand All @@ -382,12 +382,12 @@ struct Full : NetworkBase {
;
else
THROW("unknown normalization");
W += lr * d_W;
w += lr * d_w;
W += lr * W.d;
w += lr * w.d;
nsteps = 0;
nseq = 0;
d_W *= momentum;
d_w *= momentum;
W.d *= momentum;
w.d *= momentum;
}
void myweights(const string &prefix, WeightFun f) {
f(prefix + ".W", &W, (Mat *)0);
Expand Down Expand Up @@ -706,12 +706,10 @@ void each(F f, T &a, Args &&... args) {

template <class F = SigmoidNonlin, class G = TanhNonlin, class H = TanhNonlin>
struct GenericNPLSTM : NetworkBase {
#define SEQUENCES gi, gf, go, ci, state
// #define DSEQUENCES gierr, gferr, goerr, cierr, stateerr, outerr
#define WEIGHTS WGI, WGF, WGO, WCI
#define DWEIGHTS DWGI, DWGF, DWGO, DWCI
#define SEQUENCES gi, gf, go, ci, state
Sequence source, SEQUENCES;
Mat WEIGHTS, DWEIGHTS;
Params WEIGHTS;
Float gradient_clipping = 10.0;
int ni, no, nf;
int nsteps = 0;
Expand Down Expand Up @@ -751,7 +749,7 @@ struct GenericNPLSTM : NetworkBase {
clearUpdates();
}
void clearUpdates() {
each([this](Mat &d) { d = Mat::Zero(no, nf); }, DWEIGHTS);
each([this](Params &w) { w.d = Mat::Zero(no, nf); }, WEIGHTS);
}
void resize(int N) {
each([N](Sequence &s) {
Expand All @@ -763,7 +761,6 @@ struct GenericNPLSTM : NetworkBase {
assert(gi.size() == N);
assert(go.size() == N);
}
#define A array()
void forward() {
int N = inputs.size();
resize(N);
Expand Down Expand Up @@ -814,10 +811,10 @@ struct GenericNPLSTM : NetworkBase {
gradient_clip(state, gradient_clipping);
}
for (int t = 0; t < N; t++) {
DWGI += MATMUL_RT(gi[t].d, source[t]);
if (t > 0) DWGF += MATMUL_RT(gf[t].d, source[t]);
DWGO += MATMUL_RT(go[t].d, source[t]);
DWCI += MATMUL_RT(ci[t].d, source[t]);
WGI.d += MATMUL_RT(gi[t].d, source[t]);
if (t > 0) WGF.d += MATMUL_RT(gf[t].d, source[t]);
WGO.d += MATMUL_RT(go[t].d, source[t]);
WCI.d += MATMUL_RT(ci[t].d, source[t]);
}
nsteps += N;
nseq += 1;
Expand All @@ -840,20 +837,20 @@ struct GenericNPLSTM : NetworkBase {
;
else
THROW("unknown normalization");
WGI += lr * DWGI;
WGF += lr * DWGF;
WGO += lr * DWGO;
WCI += lr * DWCI;
DWGI *= momentum;
DWGF *= momentum;
DWGO *= momentum;
DWCI *= momentum;
WGI += lr * WGI.d;
WGF += lr * WGF.d;
WGO += lr * WGO.d;
WCI += lr * WCI.d;
WGI.d *= momentum;
WGF.d *= momentum;
WGO.d *= momentum;
WCI.d *= momentum;
}
void myweights(const string &prefix, WeightFun f) {
f(prefix + ".WGI", &WGI, &DWGI);
f(prefix + ".WGF", &WGF, &DWGF);
f(prefix + ".WGO", &WGO, &DWGO);
f(prefix + ".WCI", &WCI, &DWCI);
f(prefix + ".WGI", &WGI, &WGI.d);
f(prefix + ".WGF", &WGF, &WGF.d);
f(prefix + ".WGO", &WGO, &WGO.d);
f(prefix + ".WCI", &WCI, &WCI.d);
}
virtual void mystates(const string &prefix, StateFun f) {
f(prefix + ".inputs", &inputs);
Expand Down
5 changes: 5 additions & 0 deletions clstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ struct Batch : Mat {
using Mat::Mat;
Mat d;
};
struct Params : Mat {
using Mat::Mat;
Mat d;
bool is_params() { return true; }
};
#endif

// typedef vector<Mat> Sequence;
Expand Down
2 changes: 1 addition & 1 deletion run-uw3-500
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ hidden=100
lrate=1e-4
save_name=uw3-500
report_time=1
gdb --ex run --args \
# gdb --ex run --args \
./clstmocrtrain uw3-train uw3-test

0 comments on commit fa76c70

Please sign in to comment.