-
Notifications
You must be signed in to change notification settings - Fork 28
/
EncDec.hpp
144 lines (125 loc) · 4.35 KB
/
EncDec.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#pragma once
#include "LSTM.hpp"
#include "Vocabulary.hpp"
#include "SoftMax.hpp"
#include "BlackOut.hpp"
// Sequence-to-sequence model declaration: two LSTMs (enc, dec), source and
// target embedding matrices, and two output-layer objects (SoftMax, BlackOut).
// Only declarations appear here; every member function is defined elsewhere.
class EncDec{
public:
// Nested helper types; their definitions follow this class in this header.
class Data;
class Grad;
class DecCandidate;
class ThreadArg;
// Takes the vocabularies and the train/dev sets by reference, plus the layer
// sizes and the output-layer switch. The definition is not in this header;
// presumably it binds the reference members declared below -- confirm there.
EncDec(Vocabulary& sourceVoc_, Vocabulary& targetVoc_,
std::vector<EncDec::Data*>& trainData_, std::vector<EncDec::Data*>& devData_,
const int inputDim, const int hiddenDim,
const bool useBlackout_);
// ThreadArg's constructor reads this flag to decide whether to set up the
// BlackOut or the SoftMax gradient buffers.
bool useBlackout;
Rand rnd;
// Reference members: the referenced objects are owned by the caller and must
// outlive this EncDec instance.
Vocabulary& sourceVoc;
Vocabulary& targetVoc;
std::vector<EncDec::Data*>& trainData;
std::vector<EncDec::Data*>& devData;
LSTM enc, dec;
SoftMax softmax;
BlackOut blackout;
MatD sourceEmbed;
MatD targetEmbed;
VecD zeros;
// NOTE(review): presumably LSTM state buffers used while evaluating devData;
// the meaning of the outer/inner indices is not visible here -- confirm in
// the implementation file.
std::vector<std::vector<LSTM::State*> > encStateDev, decStateDev;
void encode(const std::vector<int>& src, std::vector<LSTM::State*>& encState);
// Two translate overloads: this one returns nothing and accepts showNum; the
// next one writes into `output` and returns a bool whose meaning is defined
// in the implementation file.
void translate(const std::vector<int>& src, const int beam = 1, const int maxLength = 100, const int showNum = 1);
bool translate(std::vector<int>& output, const std::vector<int>& src, const int beam = 1, const int maxLength = 100);
Real calcLoss(EncDec::Data* data, std::vector<LSTM::State*>& encState, std::vector<LSTM::State*>& decState);
Real calcPerplexity(EncDec::Data* data, std::vector<LSTM::State*>& encState, std::vector<LSTM::State*>& decState);
// Two gradCheck overloads: one takes a whole EncDec::Grad, the other a single
// parameter matrix together with its gradient matrix.
void gradCheck(EncDec::Data* data, std::vector<LSTM::State*>& encState, std::vector<LSTM::State*>& decState, EncDec::Grad& grad);
void gradCheck(EncDec::Data* data, std::vector<LSTM::State*>& encState, std::vector<LSTM::State*>& decState, MatD& param, const MatD& grad);
// `grad` and `loss` are non-const references, i.e. in/out parameters.
void train(EncDec::Data* data, std::vector<LSTM::State*>& encState, std::vector<LSTM::State*>& decState, EncDec::Grad& grad, Real& loss);
void trainOpenMP(const Real learningRate, const int miniBatchSize = 1, const int numThreads = 1);
void save(const std::string& fileName);
void load(const std::string& fileName);
// Static entry point taking four file names (source/target for train and dev).
static void demo(const std::string& srcTrain, const std::string& tgtTrain, const std::string& srcDev, const std::string& tgtDev);
};
// One training/dev example: a source sequence paired with a target sequence,
// each stored as a vector of int ids.
class EncDec::Data{
public:
  std::vector<int> src;  // source-side ids
  std::vector<int> tgt;  // target-side ids
};
// Gradient accumulator for an EncDec model.
//
// Embedding gradients are kept sparse (only the ids that were touched get an
// entry); the LSTM and output-layer gradients are delegated to their own
// Grad types.
class EncDec::Grad{
public:
  // Sparse embedding gradients keyed by an int id.
  // NOTE(review): presumably the id is a vocabulary index -- confirm at the
  // call sites that fill these maps.
  std::unordered_map<int, VecD> sourceEmbed, targetEmbed;
  LSTM::Grad lstmSrcGrad;
  LSTM::Grad lstmTgtGrad;
  SoftMax::Grad softmaxGrad;
  BlackOut::Grad blackoutGrad;
  BlackOut::State blackoutState;

  // Resets the accumulator: drops every sparse embedding entry and calls
  // init() on each component gradient. blackoutState is not touched here.
  void init(){
    this->sourceEmbed.clear();
    this->targetEmbed.clear();
    this->lstmSrcGrad.init();
    this->lstmTgtGrad.init();
    this->softmaxGrad.init();
    this->blackoutGrad.init();
  }

  // Returns the sum of the four component norm() values plus the squared L2
  // norm of every stored embedding gradient.
  // NOTE(review): whether the component norm() calls also return squared
  // norms is not visible from this header -- confirm before taking a sqrt.
  Real norm(){
    Real res = this->lstmSrcGrad.norm()+this->lstmTgtGrad.norm()+this->softmaxGrad.norm()+this->blackoutGrad.norm();

    for (const auto& entry : this->sourceEmbed){
      res += entry.second.squaredNorm();
    }
    for (const auto& entry : this->targetEmbed){
      res += entry.second.squaredNorm();
    }

    return res;
  }

  // Adds `grad` into this accumulator, component by component. For the sparse
  // embedding maps, an id already present is summed and a new id is copied in.
  void operator += (const EncDec::Grad& grad){
    this->lstmSrcGrad += grad.lstmSrcGrad;
    this->lstmTgtGrad += grad.lstmTgtGrad;
    this->softmaxGrad += grad.softmaxGrad;
    this->blackoutGrad += grad.blackoutGrad;

    // One hash lookup per existing key (find) instead of count() followed by
    // at(); new keys are copy-constructed in place instead of being
    // default-constructed through operator[] and then assigned.
    auto mergeSparse = [](std::unordered_map<int, VecD>& dst, const std::unordered_map<int, VecD>& src){
      for (const auto& entry : src){
        auto found = dst.find(entry.first);

        if (found != dst.end()){
          found->second += entry.second;
        }
        else {
          dst.emplace(entry.first, entry.second);
        }
      }
    };

    mergeSparse(this->sourceEmbed, grad.sourceEmbed);
    mergeSparse(this->targetEmbed, grad.targetEmbed);
  }
};
// A decoding candidate: an accumulated score, the target ids held so far, the
// decoder LSTM states associated with it, and a flag marking it as stopped.
// NOTE(review): presumably one beam-search hypothesis used by translate() --
// confirm in the implementation file.
class EncDec::DecCandidate{
public:
  // Members start from the defaults given at their declarations below
  // (score 0.0, stop false, both vectors empty).
  DecCandidate() {}

  Real score = 0.0;
  std::vector<int> tgt;
  std::vector<LSTM::State*> decState;
  bool stop = false;
};
// Bundle of state bound to one EncDec model: a gradient accumulator, a loss
// accumulator, an index pair (beg, end), and LSTM state buffers.
// NOTE(review): presumably one instance per worker thread in trainOpenMP --
// confirm in the implementation file.
class EncDec::ThreadArg{
public:
  // Binds to `encdec_` and sizes the gradient buffers from the model's
  // layers. Only the output-layer gradient that matches the model's
  // useBlackout flag is set up; the other keeps its default-constructed state.
  //
  // beg/end are zero-initialized so that reading them before the caller
  // assigns a range is well-defined rather than undefined behaviour.
  ThreadArg(EncDec& encdec_):
    beg(0), end(0), encdec(encdec_), loss(0.0)
  {
    this->grad.lstmSrcGrad = LSTM::Grad(this->encdec.enc);
    this->grad.lstmTgtGrad = LSTM::Grad(this->encdec.dec);

    if (this->encdec.useBlackout){
      this->grad.blackoutState = BlackOut::State(this->encdec.blackout);
      this->grad.blackoutGrad = BlackOut::Grad();
    }
    else {
      this->grad.softmaxGrad = SoftMax::Grad(this->encdec.softmax);
    }
  }

  // NOTE(review): presumably a [beg, end) range of data indices assigned by
  // the caller -- confirm the bounds convention at the use site.
  int beg, end;
  EncDec& encdec;   // model this argument is bound to; must outlive it
  EncDec::Grad grad;
  Real loss;
  std::vector<LSTM::State*> encState, decState;
};