diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index 2b7f9e0e7..2d09cdbc3 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -163,24 +163,16 @@ def __call__(
         * (self.max_timescale / self.min_timescale) ** fraction
     )
     position = position[:, :, jnp.newaxis, jnp.newaxis]
-    timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
     sinusoid_inp = position / timescale
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
-    reshape_tensor = inputs.astype(jnp.float32).reshape(
-        *inputs.shape[:-1], 2, -1
-    )
-    reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2)
-    first_half = reshape_tensor[..., 0]
-    second_half = reshape_tensor[..., 1]
+    first_half, second_half = jnp.split(inputs, 2, axis=-1)
     first_part = first_half * cos - second_half * sin
     second_part = second_half * cos + first_half * sin
     if self.cast_as_fprop_dtype:
       first_part = first_part.astype(self.fprop_dtype)
       second_part = second_part.astype(self.fprop_dtype)
-    x_out = jnp.stack((first_part, second_part), axis=-1).reshape(
-        *first_part.shape[:-1], -1
-    )
+    x_out = jnp.concatenate((first_part, second_part), axis=-1)
     return x_out
 
 
diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py
index ccbeb7246..38e1a7c62 100644
--- a/MaxText/tests/llama_test.py
+++ b/MaxText/tests/llama_test.py
@@ -109,8 +109,8 @@ def test_rope(self):
     key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)
 
     # Compare results
-    self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
-    self.assertTrue(jax.numpy.allclose(llama_output[1], key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
+    self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
+    self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
 
   def test_scaling_rope(self):
     dim_per_head = 128
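
Note (illustrative, not part of the patch): both the old and the new kernel rotate the same contiguous halves of the input, but the old one re-interleaved the rotated halves on output (f0, s0, f1, s1, ...) while the new one concatenates them (f0, f1, ..., s0, s1, ...). The two layouts differ only by a fixed permutation, which is why the test now also runs llama_output through permute_to_match_maxtext_rope before comparing. A minimal sketch of that equivalence, using toy 1-D vectors rather than the actual MaxText tensors:

import jax.numpy as jnp

# Stand-ins for the rotated halves computed inside __call__.
first_part = jnp.arange(4.0)
second_part = jnp.arange(4.0) + 10.0

old_out = jnp.stack((first_part, second_part), axis=-1).reshape(-1)  # interleaved layout
new_out = jnp.concatenate((first_part, second_part), axis=-1)        # half-split layout

# Fixed permutation mapping the interleaved layout onto the half-split one.
perm = jnp.concatenate((jnp.arange(0, 8, 2), jnp.arange(1, 8, 2)))
assert jnp.array_equal(old_out[perm], new_out)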