Skip to content

Commit

Permalink
Use tf.gather instead of tf.einsum to gather slices of tensors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665950459
  • Loading branch information
lingvo-bot authored and copybara-github committed Aug 21, 2024
1 parent 2474653 commit 447109a
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions lingvo/core/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7034,9 +7034,6 @@ def MultiTaskProjection(
b_task = 'b'
t_task = 't'

# [num_tasks] or [batch, num_tasks] or [batch, time, num_tasks]
tasks_onehot = tf.one_hot(tasks, num_tasks, axis=-1, dtype=inputs.dtype)

# Einsum axis names:
# b - batch (b_task, if the corresponding tensor has batch dimension)
# t - time (t_input and t_task, if corresponding tensors have time dimension)
Expand All @@ -7050,9 +7047,7 @@ def MultiTaskProjection(
weights = quant_layer.ToAqtWeight(w_q_name, weights, feature_axis=-1)
# select..
# [{batch,} {time,} input_dim, output_dim]
selected_weights = tf.einsum(
f'{b_task}{t_task}k,kio->{b_task}{t_task}io', tasks_onehot, weights
)
selected_weights = tf.gather(weights, tasks)
if qat_output:
# .. and multiply
# [batch, {time,} output_dim]
Expand All @@ -7077,20 +7072,19 @@ def MultiTaskProjection(
all_projected = tf.einsum(f'b{t_input}i,kio->b{t_input}ko', inputs, weights)
# .. and select
# [batch, {time,} output_dim]
out = tf.einsum(
f'b{t_input}ko,{b_task}{t_task}k->b{t_input}o',
out = tf.gather(
all_projected,
tasks_onehot,
tasks,
axis=1 if time_size is None else 2, # where's 'task' in all_projected?
batch_dims=GetRank(tasks), # do we have batch and/or time in tasks?
)
else:
raise ValueError(
'einsum_order must be select_and_multiply or multiply_and_select.'
)
if biases is not None:
# [{batch,} {time,} output_dim]
bias = tf.einsum(
f'{b_task}{t_task}k,ko->{b_task}{t_task}o', tasks_onehot, biases
)
bias = tf.gather(biases, tasks)

# If `out` has time dimension (`bto`), and `tasks` has batch but no time
# (`bo`), then we need to expand the bias on the second dimension for
Expand Down

0 comments on commit 447109a

Please sign in to comment.