
Commit f910d4f

Fix a bug in RotaryPositionalEmbeddingLayer.
The float32 timescale was cast to int32 when int32 positions were passed to the RoPE layer, which loses precision.

PiperOrigin-RevId: 671851096
lingvo-bot authored and copybara-github committed Sep 6, 2024
1 parent 447109a commit f910d4f
Showing 2 changed files with 11 additions and 4 deletions.
lingvo/core/layers.py (4 changes: 3 additions & 1 deletion)
@@ -3515,7 +3515,9 @@ def FProp(self, theta, inputs, position=None):
     position = position[:, :, tf.newaxis, tf.newaxis]
     timescale = timescale[tf.newaxis, tf.newaxis, tf.newaxis, :]

-    sinusoid_inp = position / tf.cast(timescale, position.dtype)
+    sinusoid_inp = tf.cast(
+        tf.cast(position, timescale.dtype) / timescale, inputs.dtype
+    )
     sin = tf.sin(sinusoid_inp)
     cos = tf.cos(sinusoid_inp)
     first_half, second_half = tf.split(inputs, 2, axis=-1)
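Below is a minimal standalone sketch (not part of the commit; the timescale values are hypothetical) of the dtype issue being fixed: with int32 positions, the old expression cast the float32 timescale down to int32, truncating fractional timescales, whereas the new expression promotes the positions to float32 and casts only the result.

import tensorflow as tf

position = tf.constant([[0], [1], [2]], dtype=tf.int32)     # int32 positions
timescale = tf.constant([1.0, 31.623, 1000.0], tf.float32)  # float32 timescales

# Old expression: the timescale is cast to the int32 position dtype,
# so 31.623 truncates to 31 before the division.
old = position / tf.cast(timescale, position.dtype)

# New expression: the position is promoted to float32, the division happens
# in float32, and only the result is cast to the activation dtype.
new = tf.cast(tf.cast(position, timescale.dtype) / timescale, tf.float32)

print(old.numpy())  # computed with the truncated timescales
print(new.numpy())  # computed with the exact float32 timescales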
lingvo/core/layers_test.py (11 changes: 8 additions & 3 deletions)
@@ -4385,7 +4385,8 @@ def testSinusoidalPositionalEmbeddingLayer(self):
                           math.cos(p / 2 * math.pi)] for p in range(4)]
       self.assertAllClose(actual_position_embs, expected_output)

-  def testRotaryPositionalEmbeddingLayer(self):
+  @parameterized.named_parameters(('default', False), ('has_position', True))
+  def testRotaryPositionalEmbeddingLayer(self, has_position=False):
     with self.session(use_gpu=False):
       p = layers.RotaryPositionalEmbeddingLayer.Params()
       p.name = 'position_emb'
@@ -4394,15 +4395,19 @@ def testRotaryPositionalEmbeddingLayer(self):
       p.embedding_dim = 4
       seq_length = 5
       inputs = tf.ones([1, seq_length, 1, p.embedding_dim])
+      if has_position:
+        positions = tf.range(seq_length)[tf.newaxis, :]
+      else:
+        positions = None

       pos_emb_layer = p.Instantiate()
       self.evaluate(tf.global_variables_initializer())
-      position_embs = pos_emb_layer.FPropDefaultTheta(inputs)
+      position_embs = pos_emb_layer.FPropDefaultTheta(inputs, positions)
       position_embs = tf.squeeze(position_embs, axis=[0, 2])
       actual_position_embs, = self.evaluate([position_embs])

       expected_output = [
-          [1., 1., 1., 1.],
+          [1.0, 1.0, 1.0, 1.0],
           [-0.30116868, 0.5603883, 1.3817732, 1.2984471],
           [-1.3254442, 0.04166961, 0.4931506, 1.4135995],
           [-1.1311125, -0.48293126, -0.8488725, 1.3292018],
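For reference, a hedged usage sketch (assuming a working lingvo/TensorFlow setup and the layer's default timescale params; it mirrors the shapes used in the test above) showing how explicit int32 positions can now be passed without truncating the timescale:

p = layers.RotaryPositionalEmbeddingLayer.Params()
p.name = 'position_emb'
p.embedding_dim = 4
pos_emb_layer = p.Instantiate()

inputs = tf.ones([1, 5, 1, p.embedding_dim])   # [batch, seq, heads, dim]
positions = tf.range(5)[tf.newaxis, :]         # int32 positions, as in the test
outputs = pos_emb_layer.FPropDefaultTheta(inputs, positions)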
