-
Notifications
You must be signed in to change notification settings - Fork 0
/
Seq2Seq
60 lines (54 loc) · 5.26 KB
/
Seq2Seq
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
Helper Class
sample函数:根据output(一般是经过output layer投影的rnn output)对candidate进行sampling。
next_inputs:计算下一个step的输入向量。比如TrainingHelpler就将input转换为TensorArray,再从其中read。
BasicDecoderOutput Class
数据储存功能:存着每一个step的rnn output(经过output layer投影)和按照helper的sample策略得到的词的id
是一个namedtuple,由rnn_output(形状为 b,vocab_size)和sample_id (形状为b) 两个tensor组成。
Decoder Class
执行一个step的解码:将各个组件使用起来对一个batch的数据进行一次decode
1.包含了RNN cell,helper和output layer这3个部分。RNN cell一般是AttentionWrapper。
2.核心函数 step
输入:前一step中helper读取的input,前一step的state(RNN cell state或者AttentionWrapperState)。
中间:用RNN cell产生新的hidden state,或者是用AttentionWrapper对前一个step的AttentionWrapperState产生新的AttentionWrapperState。经过output layer投影,再用helper进行sampling 得到output和下一个step的input,并由helper判断是否结束decode
输出:
1. BasicDecoderOutput;
2. 下一个step的input和一个长度为b的finished的flag向量(由helper计算);
3. next_state:通常是直接传递RNN cell state或者AttentionWrapperState
step函数并不管当前batch中那些entry是finished的,因为对已经finish的进行下一步decode也不影响什么
AttentionMechanism Class
采用memory + query方式来算出alignment score
memory: 一般是rnn的hidden states,shape = [batch,max_time,rnn_num_units]
_values: mask之后的memory
_keys:将_values经过投影之后得到的tensor,形状为[batch,max_time,attn_num_units],投影采用的是tf.layers.Dense(bias为0)。也是被mask过了的
__call__():将query进行投影,然后与_keys算alignments,alignments形状为[batch_size, max_time],表示每一个hidden state应得的unnormalized分数。
_BaseAttentionMechanism Class
AttentionMechanism Class的基类,会对values,keys分别进行mask,还会对probability_fn进行包装,在计算score时将padding部分的注意力分数mask为0
AttentionWrapperState Class
cell_state:rnn cell在前一步的原始state,形状为 batch,rnn_output_size
attention: 如果attention layer是None的话,attention就是context向量
alignments:形状为 b,s
alignment_history:TensorArray,每一个tensor的形状为b,s,对应那个step的alignments向量
AttenWrapper Class
1.是RNNCell的subclass。
2.call() 函数
进行一次decode,不是像论文中那样是先算各个hidden state的weight。weighted sum的过程由上一个time step计算出decoder的hidden state之后直接进行,再通过output将上一个time step的hidden state和weighted sum传递给当前的step。
1. Mix the `inputs` and previous step's `attention(context vector, weighted sum)` output via`cell_input_fn`.
2. Call the wrapped `cell` with this input and its current state.
3. Score the cell's output with `attention_mechanism`.
4. Calculate the alignments by passing the score through the `normalizer`. alignments向量是归一化过后的向量,[b,s]
5. Calculate the context vector as the inner product between the alignments and the attention_mechanism's values (memory).
6. Calculate the attention 向量: 如果attention layer是None的话,attention就是context向量。否则 concatenating the cell output
and context through the attention layer (a linear layer with `attention_layer_size` outputs)。
7. 输出: 对应原rnn cell output:如果output_attention是true,就输出attention向量,否则输出rnn cell output。
对应原rnn cell state:一个AttentionWrapperState的instance。这个state在之后的计算中会被传给helper做sample和next_input计算用。但helper并没有用到state,所以不知道API设计成这样是要干什么 :(
4 alignment history的保存机制
1.初始化时(使用zero state函数初始化)是一个size为0的TensorArray
2.动态解码的过程中,在每一个step时,AttenWrapper向初始化时的TensorArray中写入当前step计算出的alignments向量
3.解码结束后, alignment history被保存在AttentionWrapperState,一般被直接传递给dynamic_decode返回的final_state,未经过转换。
dynamic_decode 方法
1. 使用while_loop函数来创造动态图,在每一time step中计算该batch中的每一个entry是否finish以及它的final length是多少,将input转换为Tensor
Array,使用TensorArray来记录每一个time step产生的结果。
2. 返回的final_outputs, final_state, final_sequence_lengths中:
final_outputs:把每一step的BasicDecoderOutput的rnn_output(形状为 b,vocab_size)和sample_id (形状为b) 两个tensor写到对应的TensorArray中,并进行stack转化,得到名字还是为rnn_output和sample_id的tensor。形状为 b, s 和 b,s,vocab_size
final_state:一般为AttentionWrapperState,其中的cell_state,attention,alignments都是最后一个step时的值。alignment_history是一个TensorArray,记录了每一个step的alignments向量
final_sequence_lengths:形状为 b,一般不能直接使用。GreedyEmbeddingHelper算出来的final_sequence_lengths是正确的,TrainingHelper算出来的一般是最大句子长度。