diff --git a/clstm.cc b/clstm.cc index 0b0cadd..7a76700 100644 --- a/clstm.cc +++ b/clstm.cc @@ -877,7 +877,8 @@ struct GenericNPLSTM : NetworkBase { backward_statemem(state[t], ci[t], gi[t], state, t-1, gf[t]); gradient_clip(state[t].d, gradient_clipping); backward_full(gi[t], WGI, source[t], gradient_clipping); - if (t>0) backward_full(gf[t], WGF, source[t], gradient_clipping); + assert(gf[0].d.maxCoeff()==0); + backward_full(gf[t], WGF, source[t], gradient_clipping); backward_full(go[t], WGO, source[t], gradient_clipping); backward_full(ci[t], WCI, source[t], gradient_clipping); backward_stack1(source[t], inputs[t], out, t-1);