Commit eb1f947
Avoid dynamic shape computations in MultiHeadedAttention/PackSource.
PiperOrigin-RevId: 687059900
lingvo-bot authored and copybara-github committed Oct 17, 2024
1 parent b26149e commit eb1f947
Showing 1 changed file with 23 additions and 5 deletions.
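
The point of the change: when the batch dimension is statically known, the reshape targets below become plain Python constants, and the single -1 is reserved for the genuinely dynamic time dimension, so no shape-arithmetic ops are emitted into the graph. A minimal standalone sketch of that effect (hypothetical dimensions and function name, plain TensorFlow rather than lingvo code):

import tensorflow as tf

NUM_HEADS = 2

@tf.function(input_signature=[tf.TensorSpec([None, 4, 16])])
def fold_heads(x):  # [time (dynamic), batch=4, dim=16]
  batch_size = x.shape[1]  # a Python int (4), known at trace time
  # The multiply happens in Python, so the reshape target is a constant
  # and the single -1 absorbs the dynamic time dimension.
  y = tf.reshape(x, [-1, batch_size * NUM_HEADS, 16 // NUM_HEADS])
  print(y.shape)  # prints (None, 8, 8) at trace time: batch * heads stays static
  return y

fold_heads(tf.zeros([7, 4, 16]))
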
lingvo/core/attention.py
@@ -1697,6 +1697,18 @@ def PackSource(
       # [time_steps, batch_size, source_dim]
       source_vecs = py_utils.HasRank(source_vecs, 3)
       time_steps, batch_size = py_utils.GetShape(source_vecs, 2)
+
+      # calculate time_steps and batch_size * num_heads, to be used as time and
+      # batch size for the internal attention, but avoid dynamic size calculations
+      # whenever possible.
+      if source_vecs.shape.ndims is None or source_vecs.shape[1] is None:
+        # batch_size is dynamic; avoid multiplication with num_heads
+        time_steps_for_internal_attn = time_steps
+        batch_size_for_internal_attn = -1
+      else:
+        time_steps_for_internal_attn = -1
+        batch_size_for_internal_attn = batch_size * num_heads
+
       # [time_steps, batch_size, context_dim]
       source_contexts = py_utils.HasShape(
           source_contexts, [time_steps, batch_size, -1]
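
The branch works because py_utils.GetShape returns Python ints for statically known dimensions and tensors for dynamic ones, so batch_size * num_heads would otherwise become a runtime multiply whenever batch is dynamic. A plain-TF restatement of the decision (a sketch under that assumption; internal_attn_dims is a made-up name, not the lingvo API):

import tensorflow as tf

def internal_attn_dims(x, num_heads):
  # Restates the branch above with plain TF shapes: produce (time, batch)
  # reshape targets with at most one -1, and with no tensor-valued entries
  # when the batch dimension is statically known.
  dims = x.shape.as_list() if x.shape.ndims is not None else [None, None, None]
  if dims[1] is None:
    # Dynamic batch: -1 lets reshape infer it, so batch * num_heads never
    # has to be computed at run time.
    time_steps = dims[0] if dims[0] is not None else tf.shape(x)[0]
    return time_steps, -1
  # Static batch: the head fold is plain Python arithmetic.
  return -1, dims[1] * num_heads

For x of shape [None, 4, 16] with num_heads=2 this yields (-1, 8); for [7, None, 16] it yields (7, -1), i.e. the two branches in the hunk.
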
@@ -1739,8 +1751,8 @@ def PackSource(
       source_vecs = tf.reshape(
           source_vecs,
           [
-              -1,
-              batch_size * num_heads,
+              time_steps_for_internal_attn,
+              batch_size_for_internal_attn,
               symbolic.ToStatic(p.hidden_dim // num_heads),
           ],
       )
@@ -1781,7 +1793,11 @@ def PackSource(
       # => [time_steps, batch_size * num_heads, context_dim / num_heads]
       source_contexts = tf.reshape(
           source_contexts,
-          [-1, batch_size * num_heads, context_dim // num_heads],
+          [
+              time_steps_for_internal_attn,
+              batch_size_for_internal_attn,
+              context_dim // num_heads,
+          ],
       )
       source_contexts = gshard_utils.MeshSplit(
           source_contexts, p.device_mesh, p.activation_split_dims_mapping
@@ -1798,7 +1814,8 @@ def PackSource(
         source_padding = tf.tile(source_padding, [1, 1, num_heads])
       # => [time_steps, batch_size * num_heads]
       source_padding = tf.reshape(
-          source_padding, [-1, batch_size * num_heads]
+          source_padding,
+          [time_steps_for_internal_attn, batch_size_for_internal_attn],
       )
 
     with tf.name_scope('segment_id'):
@@ -1809,7 +1826,8 @@ def PackSource(
         source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads])
       # => [time_steps, batch_size * num_heads]
       source_segment_id = tf.reshape(
-          source_segment_id, [-1, batch_size * num_heads]
+          source_segment_id,
+          [time_steps_for_internal_attn, batch_size_for_internal_attn],
       )
 
     return self.atten.PackSource(
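
All four reshapes (vecs, contexts, padding, segment ids) now share the same precomputed (time, batch) pair. A quick standalone check of the resulting static shapes for the padding path (hypothetical dims and function name, not lingvo code):

import tensorflow as tf

num_heads, batch_size, dim = 2, 4, 16

@tf.function(input_signature=[
    tf.TensorSpec([None, batch_size, dim]),  # source_vecs: [time, batch, dim]
    tf.TensorSpec([None, batch_size]),       # source_padding: [time, batch]
])
def pack(vecs, padding):
  # Static batch, so both reshape targets are Python constants; -1 covers time.
  vecs = tf.reshape(vecs, [-1, batch_size * num_heads, dim // num_heads])
  padding = tf.tile(padding[:, :, tf.newaxis], [1, 1, num_heads])
  padding = tf.reshape(padding, [-1, batch_size * num_heads])
  return vecs, padding

cf = pack.get_concrete_function()
print([t.shape for t in cf.outputs])
# [TensorShape([None, 8, 8]), TensorShape([None, 8])]

Both outputs keep a static batch * num_heads axis of 8; only time stays None, which is what the precomputed pair buys when batch is static.
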
