-
Notifications
You must be signed in to change notification settings - Fork 0
/
RNN
23 lines (18 loc) · 1.2 KB
/
RNN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
要实现一个RNN层,需要两个部分:rnn_cell_impl.py里的RNNCell和rnn.py里的辅助函数
#RNNCell class
1.任何RNN cell(BasicRNN Cell,GRU cell或者LSTM Cell)都继承至这个类
2.这个类继承自tensorflow.python.layers.base.Layer类,主要是为了overriding __call__这个函数接口
3. 在__init__里并没有向图里添加变量和节点
4. 调用call()的时候,才向计算图里添加变量和op,这往往发生在dynamic_rnn的while_loop调用它的body函数时
#dynamic_rnn函数
1. 从max_time的维度对input用RNNCell进行计算
2. 每一个time step得到的RNNCell state被写进TensorArray里
3. 对于一个batch中已经超出sequence length的time step,输出的state是复制前一个time step的状态,output则为zero tensor
BasicRNN Cell
1.输出的output就是state
GRU Cell
1.输出的output就是state
BasicLSTM Cell
state是一个namedtuple:LSTMStateTuple,其中'c'代表cell state, 'h'代表hidden state
1. 因为forget gate,input_gate, output_gate还有cell state的new input都依赖于输入和h,所以他们被用同一个linear layer来计算
2. 返回的(output, next_state)里,output是h,next_state是包含了c和h的LSTMStateTuple