diff --git a/modeling.py b/modeling.py index 8b5da0003..ea575220a 100644 --- a/modeling.py +++ b/modeling.py @@ -740,12 +740,12 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) if do_return_2d_tensor: - # `context_layer` = [B*F, N*V] + # `context_layer` = [B*F, N*H] context_layer = tf.reshape( context_layer, [batch_size * from_seq_length, num_attention_heads * size_per_head]) else: - # `context_layer` = [B, F, N*V] + # `context_layer` = [B, F, N*H] context_layer = tf.reshape( context_layer, [batch_size, from_seq_length, num_attention_heads * size_per_head])