diff --git a/lingvo/core/py_utils.py b/lingvo/core/py_utils.py index 836738490..7a342c784 100644 --- a/lingvo/core/py_utils.py +++ b/lingvo/core/py_utils.py @@ -7034,9 +7034,6 @@ def MultiTaskProjection( b_task = 'b' t_task = 't' - # [num_tasks] or [batch, num_tasks] or [batch, time, num_tasks] - tasks_onehot = tf.one_hot(tasks, num_tasks, axis=-1, dtype=inputs.dtype) - # Einsum axis names: # b - batch (b_task, if the corresponding tensor has batch dimension) # t - time (t_input and t_task, if corresponding tensors have time dimension) @@ -7050,9 +7047,7 @@ def MultiTaskProjection( weights = quant_layer.ToAqtWeight(w_q_name, weights, feature_axis=-1) # select.. # [{batch,} {time,} input_dim, output_dim] - selected_weights = tf.einsum( - f'{b_task}{t_task}k,kio->{b_task}{t_task}io', tasks_onehot, weights - ) + selected_weights = tf.gather(weights, tasks) if qat_output: # .. and multiply # [batch, {time,} output_dim] @@ -7077,10 +7072,11 @@ def MultiTaskProjection( all_projected = tf.einsum(f'b{t_input}i,kio->b{t_input}ko', inputs, weights) # .. and select # [batch, {time,} output_dim] - out = tf.einsum( - f'b{t_input}ko,{b_task}{t_task}k->b{t_input}o', + out = tf.gather( all_projected, - tasks_onehot, + tasks, + axis=1 if time_size is None else 2, # where's 'task' in all_projected? + batch_dims=GetRank(tasks), # do we have batch and/or time in tasks? ) else: raise ValueError( @@ -7088,9 +7084,7 @@ def MultiTaskProjection( ) if biases is not None: # [{batch,} {time,} output_dim] - bias = tf.einsum( - f'{b_task}{t_task}k,ko->{b_task}{t_task}o', tasks_onehot, biases - ) + bias = tf.gather(biases, tasks) # If `out` has time dimension (`bto`), and `tasks` has batch but no time # (`bo`), then we need to expand the bias on the second dimension for