diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index 2b7f9e0e7..2d09cdbc3 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -163,24 +163,16 @@ def __call__(
         * (self.max_timescale / self.min_timescale) ** fraction
     )
     position = position[:, :, jnp.newaxis, jnp.newaxis]
-    timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
     sinusoid_inp = position / timescale
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
-    reshape_tensor = inputs.astype(jnp.float32).reshape(
-        *inputs.shape[:-1], 2, -1
-    )
-    reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2)
-    first_half = reshape_tensor[..., 0]
-    second_half = reshape_tensor[..., 1]
+    first_half, second_half = jnp.split(inputs, 2, axis=-1)
     first_part = first_half * cos - second_half * sin
     second_part = second_half * cos + first_half * sin
     if self.cast_as_fprop_dtype:
       first_part = first_part.astype(self.fprop_dtype)
       second_part = second_part.astype(self.fprop_dtype)
-    x_out = jnp.stack((first_part, second_part), axis=-1).reshape(
-        *first_part.shape[:-1], -1
-    )
+    x_out = jnp.concatenate((first_part, second_part), axis=-1)
     return x_out
 
 
diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py
index ccbeb7246..38e1a7c62 100644
--- a/MaxText/tests/llama_test.py
+++ b/MaxText/tests/llama_test.py
@@ -109,8 +109,8 @@ def test_rope(self):
     key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)
 
     # Compare results
-    self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
-    self.assertTrue(jax.numpy.allclose(llama_output[1], key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
+    self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
+    self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
 
   def test_scaling_rope(self):
     dim_per_head = 128
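
Note (illustrative, not part of the patch): both the old and the new kernel rotate the same contiguous halves of the input, but the old one re-interleaved the rotated halves on output (f0, s0, f1, s1, ...) while the new one concatenates them (f0, f1, ..., s0, s1, ...). The two layouts differ only by a fixed permutation, which is why the test now also runs llama_output through permute_to_match_maxtext_rope before comparing. A minimal sketch of that equivalence, using toy 1-D vectors rather than the actual MaxText tensors:

import jax.numpy as jnp

# Stand-ins for the rotated halves computed inside __call__.
first_part = jnp.arange(4.0)
second_part = jnp.arange(4.0) + 10.0

old_out = jnp.stack((first_part, second_part), axis=-1).reshape(-1)  # interleaved layout
new_out = jnp.concatenate((first_part, second_part), axis=-1)        # half-split layout

# Fixed permutation mapping the interleaved layout onto the half-split one.
perm = jnp.concatenate((jnp.arange(0, 8, 2), jnp.arange(1, 8, 2)))
assert jnp.array_equal(old_out[perm], new_out)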