tf.matmul allows batch multiplications, removing unnecessary reshapes.
PiperOrigin-RevId: 686236472
lingvo-bot authored and copybara-github committed Oct 15, 2024
1 parent 48e5034 commit f1f5671
Showing 1 changed file with 54 additions and 74 deletions.
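
The change relies on tf.matmul broadcasting a rank-2 right-hand operand across all leading (batch) dimensions of the left-hand operand; that is what makes the reshape round-trips below unnecessary. A minimal sketch of the equivalence, with made-up shapes that are not taken from the diff:

import tensorflow as tf

sl, b, hidden_dim = 5, 3, 8
summed = tf.random.normal([sl, b, hidden_dim])
v = tf.random.normal([hidden_dim])

# Old pattern: collapse to 2-D, multiply, reshape back.
old = tf.reshape(
    tf.matmul(tf.reshape(summed, [-1, hidden_dim]),
              tf.reshape(v, [hidden_dim, 1])),
    [sl, b])

# New pattern: [sl, b, hidden_dim] @ [hidden_dim, 1] -> [sl, b, 1];
# tf.matmul broadcasts the 2-D weight over the leading [sl, b] dims.
new = tf.squeeze(tf.matmul(summed, tf.reshape(v, [hidden_dim, 1])), -1)

tf.debugging.assert_near(old, new)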
128 changes: 54 additions & 74 deletions lingvo/core/attention.py
@@ -514,18 +514,16 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:

# Shape of summed is [sl, tb/sb, sb, hidden_dim].
summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
# logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
# between v with every rows in 'summed'. Then we reshape the
# result to be of shape [sl, tb/sb, sb].
# Compute the dot product between v and every row in 'summed'.
#
# Another equivalent way is to do:
# logits = tf.reduce_sum(summed *
# tf.reshape(v, [1, 1, 1, hidden_dim]), 3)
logits = py_utils.Matmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(inputs.v, [p.hidden_dim, 1]),
)
logits = tf.reshape(logits, tf.shape(summed)[:3])
# logits = tf.reduce_sum(summed * inputs.v, 3)
# or:
# logits = tf.einsum('abcd,d->abc', summed, inputs.v)
# [sl, tb/sb, sb, 1]
logits = tf.matmul(summed, tf.reshape(inputs.v, [p.hidden_dim, 1]))
# [sl, tb/sb, sb]
logits = tf.squeeze(logits, -1)

# Take out the padding states.
source_padding = inputs.Get('source_padding', None)
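
A quick check that the three forms mentioned in the comment above really agree for the 4-D case (shapes are illustrative, and a plain tensor v stands in for inputs.v):

import tensorflow as tf

sl, tb_over_sb, sb, hidden_dim = 4, 2, 3, 8
summed = tf.random.normal([sl, tb_over_sb, sb, hidden_dim])
v = tf.random.normal([hidden_dim])

# matmul form from the diff: [sl, tb/sb, sb, hd] @ [hd, 1], then drop the last dim.
logits_matmul = tf.squeeze(tf.matmul(summed, tf.reshape(v, [hidden_dim, 1])), -1)
# reduce_sum form: elementwise multiply, then sum over the hidden dimension.
logits_reduce = tf.reduce_sum(summed * v, 3)
# einsum form: the same contraction over the last axis.
logits_einsum = tf.einsum('abcd,d->abc', summed, v)

tf.debugging.assert_near(logits_matmul, logits_reduce)
tf.debugging.assert_near(logits_matmul, logits_einsum)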
@@ -706,14 +704,12 @@ def AttenSameBatchSize(
# [sl, b]
def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
"""Calculates atten probs with padding."""
# tf.tanh(source_vecs+query_vec) shape [sl, b, hidden_dim]
# [sl, b, hidden_dim]
summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
# [-1, hidden_dim] * [hidden_dim, 1] = [-1, 1]
res = py_utils.Matmul(
tf.reshape(summed, [-1, p.hidden_dim]), tf.expand_dims(inputs.v, 1)
)
# Reshape res to [sl, b]
logits = tf.reshape(res, tf.shape(summed)[:2])
# [sl, b, 1]
logits = tf.matmul(summed, tf.expand_dims(inputs.v, 1))
# [sl, b]
logits = tf.squeeze(logits, -1)

# Take out the padding states. _source_padding is of shape [sl, b].
source_padding = inputs.Get('source_padding', None)
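
For the same-batch-size path, the trailing singleton from the [sl, b, hidden_dim] @ [hidden_dim, 1] product is dropped with tf.squeeze instead of rebuilding the shape from tf.shape(summed)[:2]; a small sketch of the shape flow, with invented sizes:

import tensorflow as tf

sl, b, hidden_dim = 6, 4, 8
summed = tf.random.normal([sl, b, hidden_dim])
v = tf.random.normal([hidden_dim])

# [sl, b, hidden_dim] @ [hidden_dim, 1] -> [sl, b, 1]
logits = tf.matmul(summed, tf.expand_dims(v, 1))
# tf.squeeze drops the known singleton axis directly and keeps the leading
# dimensions intact, with no need to recover them from tf.shape(summed).
logits = tf.squeeze(logits, -1)
print(logits.shape)  # (6, 4)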
@@ -2336,22 +2332,16 @@ def AttenLogits(inputs):
"""Generates logits."""
fns = self.fns

def CollapseOutDim(x):
return tf.reshape(x, [-1, tf.shape(x)[-1]])
sl = py_utils.GetShape(inputs.location_feats)[2]
hd = py_utils.GetShape(inputs.location_var)[1]
bs_mult = py_utils.GetShape(inputs.query_vec)[1]
sb = py_utils.GetShape(inputs.query_vec)[2]

# => [sl, sb, hd]
# [sl, tb, location_num_filters]
location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
location_hidden = py_utils.Matmul(
CollapseOutDim(location_feats), inputs.location_var
)
# [sl, tb, hd]
location_hidden = tf.matmul(location_feats, inputs.location_var)
location_hidden = self.QAct('logits_mul', location_hidden)

sl = py_utils.GetShape(location_feats)[0]
tb = py_utils.GetShape(location_feats)[1]
hd = py_utils.GetShape(inputs.location_var)[1]
location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
sb = py_utils.GetShape(inputs.query_vec)[2]
bs_mult = py_utils.GetShape(inputs.query_vec)[1]
location_hidden = tf.reshape(location_hidden, [sl, bs_mult, sb, hd])

# Shape of summed is [sl, tb/sb, sb, hidden_dim].
@@ -2362,15 +2352,14 @@ def CollapseOutDim(x):
)
summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
summed = fns.qtanh(summed)
# logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
# between v with every rows in 'summed'. Then we reshape the
# result to be of shape [sl, tb/sb, sb].
logits = py_utils.Matmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]),
# Compute the dot product between v and every row in 'summed'.
# [sl, tb/sb, sb, 1]
logits = tf.matmul(
summed, tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])
)
logits = self.QAct('logits', logits)
logits = tf.reshape(logits, py_utils.GetShape(summed)[:3])
# [sl, tb/sb, sb]
logits = tf.squeeze(logits, -1)
return logits
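
The location-feature projection above follows the same pattern: a rank-3 activation times a rank-2 weight no longer needs CollapseOutDim and a reshape back. A sketch with invented sizes, leaving out the lingvo quantization hook self.QAct:

import tensorflow as tf

sl, tb, num_filters, hidden_dim = 7, 3, 4, 8
location_feats = tf.random.normal([sl, tb, num_filters])   # after the transpose
location_var = tf.random.normal([num_filters, hidden_dim])

# Old pattern: collapse the leading dims, project, reshape back to [sl, tb, hd].
old = tf.reshape(
    tf.matmul(tf.reshape(location_feats, [-1, num_filters]), location_var),
    [sl, tb, hidden_dim])

# New pattern: one matmul; the 2-D weight broadcasts over the [sl, tb] batch dims.
new = tf.matmul(location_feats, location_var)

tf.debugging.assert_near(old, new)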

def AttenLogitsSameBatchSize(inputs: py_utils.NestedMap) -> tf.Tensor:
@@ -2390,21 +2379,14 @@ def AttenLogitsSameBatchSize(inputs: py_utils.NestedMap) -> tf.Tensor:
Returns:
logits in the shape [sl, batch_size].
"""

def CollapseOutDim(x):
return tf.reshape(x, [-1, tf.shape(x)[-1]])

fns = self.fns
# => [sl, sb, hd]

# [sl, sb, location_num_filters]
location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
location_hidden = py_utils.Matmul(
CollapseOutDim(location_feats), inputs.location_var
)
# [sl, sb, hd]

location_hidden = tf.matmul(location_feats, inputs.location_var)
location_hidden = self.QAct('logits_mul', location_hidden)
sl = tf.shape(location_feats)[0]
tb = tf.shape(location_feats)[1]
hd = tf.shape(inputs.location_var)[1]
location_hidden = tf.reshape(location_hidden, [sl, tb, hd])

# Shape of summed is [sl, sb, hidden_dim].
summed = fns.qadd(
@@ -2416,15 +2398,14 @@ def CollapseOutDim(x):
summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
summed = fns.qtanh(summed)

# logits is of shape [sl * sb, 1]. Computes dot product
# between v with every rows in 'summed'. Then we reshape the
# result to be of shape [sl, tb].
logits = py_utils.Matmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]),
# Compute the dot product between v and every row in 'summed'.
# [sl, sb, 1]
logits = tf.matmul(
summed, tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])
)
logits = self.QAct('logits', logits)
logits = tf.reshape(logits, py_utils.GetShape(summed)[:2])
# [sl, sb]
logits = tf.squeeze(logits, -1)
return logits

def Atten(
@@ -2564,14 +2545,15 @@ def AttenSameBatchSize(
self._ctx_vec = Atten

def EncodeSource(theta, vecs, ctxs):
# vecs is [time, batch, p.source_dim]
time, batch = py_utils.GetShape(vecs, 2)
ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
vecs = py_utils.Matmul(
tf.reshape(vecs, [-1, p.source_dim]), self.QWeight(theta.source_var)
)
vecs = tf.reshape(vecs, [time, batch, -1])
# [time, batch, p.hidden_dim]
vecs = tf.matmul(vecs, self.QWeight(theta.source_var))
vecs = self.QAct('encode_matmul', vecs)

ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
ctxs = tf.transpose(ctxs, [1, 0, 2])

return vecs, ctxs

self._encode_source = EncodeSource
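
A stripped-down sketch of the updated EncodeSource shape flow, leaving out the lingvo-specific pieces (self.QWeight, self.QAct, py_utils.HasShape); all sizes and the context dimension are invented:

import tensorflow as tf

def encode_source(source_var, vecs, ctxs):
  # vecs: [time, batch, source_dim]; source_var: [source_dim, hidden_dim].
  vecs = tf.matmul(vecs, source_var)    # -> [time, batch, hidden_dim], no reshape
  ctxs = tf.transpose(ctxs, [1, 0, 2])  # -> [batch, time, context_dim]
  return vecs, ctxs

time, batch, source_dim, hidden_dim, ctx_dim = 5, 2, 6, 8, 3
vecs, ctxs = encode_source(
    tf.random.normal([source_dim, hidden_dim]),
    tf.random.normal([time, batch, source_dim]),
    tf.random.normal([time, batch, ctx_dim]))
print(vecs.shape, ctxs.shape)  # (5, 2, 8) (2, 5, 3)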
@@ -2923,13 +2905,14 @@ def __init__(self, params):
p.hard_sigmoid = True

def EncodeSource(theta, vecs, ctxs):
# vecs is [time, batch, p.source_dim]
time, batch = py_utils.GetShape(vecs, 2)
# [time, batch, p.hidden_dim]
vecs = tf.matmul(vecs, theta.source_var)

ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
vecs = py_utils.Matmul(
tf.reshape(vecs, [-1, p.source_dim]), theta.source_var
)
vecs = tf.reshape(vecs, [time, batch, -1])
ctxs = tf.transpose(ctxs, [1, 0, 2])

return vecs, ctxs

self._encode_source = EncodeSource
@@ -3098,17 +3081,14 @@ def AttenLogits(inputs: py_utils.NestedMap) -> tf.Tensor:
# Shape of summed is [sl, tb/sb, sb, hidden_dim].
summed = tf.tanh(source_vecs + query_vec + energy_b)
hidden_v = inputs.hidden_g * tf.nn.l2_normalize(inputs.hidden_v, axis=0)
# logits is of shape [sl * tb/sb * sb, 1]. Computes dot product
# between v with every rows in 'summed'. Then we reshape the
# result to be of shape [sl, tb/sb, sb].
# Computes the dot product between hidden_v and every row in 'summed'.
#
# Another equivalent way is to do:
# logits = tf.reduce_sum(summed *
# tf.reshape(v, [1, 1, 1, hidden_dim]), 3)
logits = py_utils.Matmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(hidden_v, [p.hidden_dim, 1]),
)
# logits = tf.reduce_sum(summed * hidden_v, 3)
# or:
# logits = tf.einsum('abcd,d->abc', summed, hidden_v)
# [sl, tb/sb, sb, 1]
logits = tf.matmul(summed, tf.reshape(hidden_v, [p.hidden_dim, 1]))
logits += inputs.hidden_b
# [tb, sl].
logits = tf.transpose(tf.reshape(logits, [-1, tb]), [1, 0])
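
For the normalized-attention variant above, the weight-normalized vector is built first and then contracted with the same broadcasted matmul. A toy walk-through of the shapes; all sizes are invented and hidden_b is treated as a scalar bias for simplicity:

import tensorflow as tf

sl, tb, hidden_dim = 5, 4, 8         # here tb/sb == 1, so sb == tb
summed = tf.random.normal([sl, 1, tb, hidden_dim])
hidden_v = tf.random.normal([hidden_dim])
hidden_g = tf.constant(1.5)
hidden_b = tf.constant(0.1)          # assumed scalar for this sketch

# Weight-normalized v, as in the diff.
v = hidden_g * tf.nn.l2_normalize(hidden_v, axis=0)
# [sl, tb/sb, sb, hd] @ [hd, 1] -> [sl, tb/sb, sb, 1]
logits = tf.matmul(summed, tf.reshape(v, [hidden_dim, 1])) + hidden_b
# Reshape to [sl, tb] and transpose to [tb, sl], matching the final line above.
logits = tf.transpose(tf.reshape(logits, [-1, tb]), [1, 0])
print(logits.shape)  # (4, 5)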
