From c7700b0a4132dd95b5890de227089febbd1cd5df Mon Sep 17 00:00:00 2001 From: Rafi Witten Date: Fri, 12 Jan 2024 17:05:20 +0000 Subject: [PATCH] Single v4-8 step time from 1.209 to 1.199 secs. (No convergence data since it is actually bit-wise identical locally) Optimization From Blake! --- MaxText/layers/embeddings.py | 12 ++---------- MaxText/tests/llama_test.py | 4 ++-- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 6fa1fafb4..c88dcf434 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -162,22 +162,14 @@ def __call__( * (self.max_timescale / self.min_timescale) ** fraction ) position = position[:, :, jnp.newaxis, jnp.newaxis] - timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :] sinusoid_inp = position / timescale sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) - reshape_tensor = inputs.astype(jnp.float32).reshape( - *inputs.shape[:-1], 2, -1 - ) - reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2) - first_half = reshape_tensor[..., 0] - second_half = reshape_tensor[..., 1] + first_half, second_half = jnp.split(inputs, 2, axis=-1) first_part = first_half * cos - second_half * sin second_part = second_half * cos + first_half * sin if self.cast_as_fprop_dtype: first_part = first_part.astype(self.fprop_dtype) second_part = second_part.astype(self.fprop_dtype) - x_out = jnp.stack((first_part, second_part), axis=-1).reshape( - *first_part.shape[:-1], -1 - ) + x_out = jnp.concatenate((first_part, second_part), axis=-1) return x_out diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index db7bca697..d0f814ca5 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -109,8 +109,8 @@ def test_rope(self): key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position) # Compare results - 
self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) - self.assertTrue(jax.numpy.allclose(llama_output[1], key_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) if __name__ == '__main__': unittest.main()