diff --git a/lingvo/core/attention.py b/lingvo/core/attention.py
index 633ffe2de..7d16d7b1e 100644
--- a/lingvo/core/attention.py
+++ b/lingvo/core/attention.py
@@ -514,18 +514,16 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
       # Shape of summed is [sl, tb/sb, sb, hidden_dim].
       summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
-      # logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
-      # between v with every rows in 'summed'. Then we reshape the
-      # result to be of shape [sl, tb/sb, sb].
+      # Compute the dot product between v and every row of 'summed'.
       #
       # Another equivalent way is to do:
-      #   logits = tf.reduce_sum(summed *
-      #                          tf.reshape(v, [1, 1, 1, hidden_dim]), 3)
-      logits = py_utils.Matmul(
-          tf.reshape(summed, [-1, p.hidden_dim]),
-          tf.reshape(inputs.v, [p.hidden_dim, 1]),
-      )
-      logits = tf.reshape(logits, tf.shape(summed)[:3])
+      #   logits = tf.reduce_sum(summed * inputs.v, 3)
+      # or:
+      #   logits = tf.einsum('abcd,d->abc', summed, inputs.v)
+      # [sl, tb/sb, sb, 1]
+      logits = tf.matmul(summed, tf.reshape(inputs.v, [p.hidden_dim, 1]))
+      # [sl, tb/sb, sb]
+      logits = tf.squeeze(logits, -1)
       # Take out the padding states.
       source_padding = inputs.Get('source_padding', None)
@@ -706,14 +704,12 @@ def AttenSameBatchSize(
     # [sl, b]
     def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
       """Calculates atten probs with padding."""
-      # tf.tanh(source_vecs+query_vec) shape [sl, b, hidden_dim]
+      # [sl, b, hidden_dim]
       summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
-      # [-1, hidden_dim] * [hidden_dim, 1] = [-1, 1]
-      res = py_utils.Matmul(
-          tf.reshape(summed, [-1, p.hidden_dim]), tf.expand_dims(inputs.v, 1)
-      )
-      # Reshape res to [sl, b]
-      logits = tf.reshape(res, tf.shape(summed)[:2])
+      # [sl, b, 1]
+      logits = tf.matmul(summed, tf.expand_dims(inputs.v, 1))
+      # [sl, b]
+      logits = tf.squeeze(logits, -1)
       # Take out the padding states. _source_padding is of shape [sl, b].
       source_padding = inputs.Get('source_padding', None)
@@ -2336,22 +2332,16 @@ def AttenLogits(inputs):
       """Generates logits."""
       fns = self.fns
 
-      def CollapseOutDim(x):
-        return tf.reshape(x, [-1, tf.shape(x)[-1]])
+      sl = py_utils.GetShape(inputs.location_feats)[2]
+      hd = py_utils.GetShape(inputs.location_var)[1]
+      bs_mult = py_utils.GetShape(inputs.query_vec)[1]
+      sb = py_utils.GetShape(inputs.query_vec)[2]
 
-      # => [sl, sb, hd]
+      # [sl, tb, location_num_filters]
       location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
-      location_hidden = py_utils.Matmul(
-          CollapseOutDim(location_feats), inputs.location_var
-      )
+      # [sl, tb, hd]
+      location_hidden = tf.matmul(location_feats, inputs.location_var)
       location_hidden = self.QAct('logits_mul', location_hidden)
-
-      sl = py_utils.GetShape(location_feats)[0]
-      tb = py_utils.GetShape(location_feats)[1]
-      hd = py_utils.GetShape(inputs.location_var)[1]
-      location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
-      sb = py_utils.GetShape(inputs.query_vec)[2]
-      bs_mult = py_utils.GetShape(inputs.query_vec)[1]
       location_hidden = tf.reshape(location_hidden, [sl, bs_mult, sb, hd])
 
       # Shape of summed is [sl, tb/sb, sb, hidden_dim].
@@ -2362,15 +2352,14 @@ def CollapseOutDim(x):
       )
       summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
       summed = fns.qtanh(summed)
-      # logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
-      # between v with every rows in 'summed'. Then we reshape the
-      # result to be of shape [sl, tb/sb, sb].
-      logits = py_utils.Matmul(
-          tf.reshape(summed, [-1, p.hidden_dim]),
-          tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]),
+      # Compute the dot product between v and every row of 'summed'.
+      # [sl, tb/sb, sb, 1]
+      logits = tf.matmul(
+          summed, tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])
       )
       logits = self.QAct('logits', logits)
-      logits = tf.reshape(logits, py_utils.GetShape(summed)[:3])
+      # [sl, tb/sb, sb]
+      logits = tf.squeeze(logits, -1)
       return logits
 
     def AttenLogitsSameBatchSize(inputs: py_utils.NestedMap) -> tf.Tensor:
@@ -2390,21 +2379,14 @@ def AttenLogitsSameBatchSize(inputs: py_utils.NestedMap) -> tf.Tensor:
       Returns:
         logits in the shape [sl, batch_size].
       """
-
-      def CollapseOutDim(x):
-        return tf.reshape(x, [-1, tf.shape(x)[-1]])
-
       fns = self.fns
-      # => [sl, sb, hd]
+
+      # [sl, sb, location_num_filters]
       location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
-      location_hidden = py_utils.Matmul(
-          CollapseOutDim(location_feats), inputs.location_var
-      )
+
+      # [sl, sb, hd]
+      location_hidden = tf.matmul(location_feats, inputs.location_var)
       location_hidden = self.QAct('logits_mul', location_hidden)
-      sl = tf.shape(location_feats)[0]
-      tb = tf.shape(location_feats)[1]
-      hd = tf.shape(inputs.location_var)[1]
-      location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
 
       # Shape of summed is [sl, sb, hidden_dim].
       summed = fns.qadd(
@@ -2416,15 +2398,14 @@ def CollapseOutDim(x):
       summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
       summed = fns.qtanh(summed)
-      # logits is of shape [sl * sb, 1]. Computes dot product
-      # between v with every rows in 'summed'. Then we reshape the
-      # result to be of shape [sl, tb].
-      logits = py_utils.Matmul(
-          tf.reshape(summed, [-1, p.hidden_dim]),
-          tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]),
+      # Compute the dot product between v and every row of 'summed'.
+      # [sl, sb, 1]
+      logits = tf.matmul(
+          summed, tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])
       )
       logits = self.QAct('logits', logits)
-      logits = tf.reshape(logits, py_utils.GetShape(summed)[:2])
+      # [sl, sb]
+      logits = tf.squeeze(logits, -1)
       return logits
 
     def Atten(
@@ -2564,14 +2545,15 @@ def AttenSameBatchSize(
     self._ctx_vec = Atten
 
     def EncodeSource(theta, vecs, ctxs):
+      # vecs is [time, batch, p.source_dim]
       time, batch = py_utils.GetShape(vecs, 2)
-      ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
-      vecs = py_utils.Matmul(
-          tf.reshape(vecs, [-1, p.source_dim]), self.QWeight(theta.source_var)
-      )
-      vecs = tf.reshape(vecs, [time, batch, -1])
+      # [time, batch, p.hidden_dim]
+      vecs = tf.matmul(vecs, self.QWeight(theta.source_var))
       vecs = self.QAct('encode_matmul', vecs)
+
+      ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
       ctxs = tf.transpose(ctxs, [1, 0, 2])
+
       return vecs, ctxs
 
     self._encode_source = EncodeSource
@@ -2923,13 +2905,14 @@ def __init__(self, params):
       p.hard_sigmoid = True
 
     def EncodeSource(theta, vecs, ctxs):
+      # vecs is [time, batch, p.source_dim]
       time, batch = py_utils.GetShape(vecs, 2)
+      # [time, batch, p.hidden_dim]
+      vecs = tf.matmul(vecs, theta.source_var)
+
       ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
-      vecs = py_utils.Matmul(
-          tf.reshape(vecs, [-1, p.source_dim]), theta.source_var
-      )
-      vecs = tf.reshape(vecs, [time, batch, -1])
       ctxs = tf.transpose(ctxs, [1, 0, 2])
+
       return vecs, ctxs
 
     self._encode_source = EncodeSource
@@ -3098,17 +3081,14 @@ def AttenLogits(inputs: py_utils.NestedMap) -> tf.Tensor:
       # Shape of summed is [sl, tb/sb, sb, hidden_dim].
       summed = tf.tanh(source_vecs + query_vec + energy_b)
       hidden_v = inputs.hidden_g * tf.nn.l2_normalize(inputs.hidden_v, axis=0)
-      # logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
-      # between v with every rows in 'summed'. Then we reshape the
-      # result to be of shape [sl, tb/sb, sb].
+      # Computes the dot product between hidden_v and every row of 'summed'.
       #
       # Another equivalent way is to do:
-      #   logits = tf.reduce_sum(summed *
-      #                          tf.reshape(v, [1, 1, 1, hidden_dim]), 3)
-      logits = py_utils.Matmul(
-          tf.reshape(summed, [-1, p.hidden_dim]),
-          tf.reshape(hidden_v, [p.hidden_dim, 1]),
-      )
+      #   logits = tf.reduce_sum(summed * hidden_v, 3)
+      # or:
+      #   logits = tf.einsum('abcd,d->abc', summed, hidden_v)
+      # [sl, tb/sb, sb, 1]
+      logits = tf.matmul(summed, tf.reshape(hidden_v, [p.hidden_dim, 1]))
       logits += inputs.hidden_b
       # [tb, sl].
       logits = tf.transpose(tf.reshape(logits, [-1, tb]), [1, 0])
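
Note (not part of the patch): the new comments claim that the batched tf.matmul, the broadcasted reduce_sum, and the einsum all produce the same logits as the old reshape + py_utils.Matmul path. Below is a minimal, standalone TF2 sketch of that equivalence; the shapes are made up and 'v' stands in for inputs.v / hidden_var / hidden_v, so nothing here is lingvo code.

# Standalone equivalence check for the dot-product rewrite above.
import numpy as np
import tensorflow as tf

sl, tb_over_sb, sb, hidden_dim = 5, 3, 2, 7
summed = tf.random.normal([sl, tb_over_sb, sb, hidden_dim])
v = tf.random.normal([hidden_dim])

# New path: tf.matmul treats [sl, tb/sb] as broadcast batch dims, multiplying
# [sb, hidden_dim] by [hidden_dim, 1] -> [sl, tb/sb, sb, 1], then squeeze.
new = tf.squeeze(tf.matmul(summed, tf.reshape(v, [hidden_dim, 1])), -1)

# The two equivalent forms mentioned in the new comments.
via_reduce_sum = tf.reduce_sum(summed * v, 3)
via_einsum = tf.einsum('abcd,d->abc', summed, v)

# Old path: flatten to 2-D, matmul, reshape back to [sl, tb/sb, sb].
old = tf.reshape(
    tf.matmul(tf.reshape(summed, [-1, hidden_dim]),
              tf.reshape(v, [hidden_dim, 1])),
    tf.shape(summed)[:3])

for other in (via_reduce_sum, via_einsum, old):
  np.testing.assert_allclose(new.numpy(), other.numpy(), rtol=1e-5, atol=1e-5)
print('all formulations agree; logits shape:', new.shape)  # (5, 3, 2)

The practical difference is that the batched matmul avoids the flatten/reshape round-trip, so leading dimensions keep their static shapes instead of collapsing to -1.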