Skip to content

Commit c7700b0

Browse files
author
Rafi Witten
committed
Single v4-8 step time improves from 1.209 to 1.199 secs. (No convergence data is included, since the output is bit-wise identical locally.)
Optimization from Blake!
1 parent b947cdd commit c7700b0

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

MaxText/layers/embeddings.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,22 +162,14 @@ def __call__(
162162
* (self.max_timescale / self.min_timescale) ** fraction
163163
)
164164
position = position[:, :, jnp.newaxis, jnp.newaxis]
165-
timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
166165
sinusoid_inp = position / timescale
167166
sin = jnp.sin(sinusoid_inp)
168167
cos = jnp.cos(sinusoid_inp)
169-
reshape_tensor = inputs.astype(jnp.float32).reshape(
170-
*inputs.shape[:-1], 2, -1
171-
)
172-
reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2)
173-
first_half = reshape_tensor[..., 0]
174-
second_half = reshape_tensor[..., 1]
168+
first_half, second_half = jnp.split(inputs, 2, axis=-1)
175169
first_part = first_half * cos - second_half * sin
176170
second_part = second_half * cos + first_half * sin
177171
if self.cast_as_fprop_dtype:
178172
first_part = first_part.astype(self.fprop_dtype)
179173
second_part = second_part.astype(self.fprop_dtype)
180-
x_out = jnp.stack((first_part, second_part), axis=-1).reshape(
181-
*first_part.shape[:-1], -1
182-
)
174+
x_out = jnp.concatenate((first_part, second_part), axis=-1)
183175
return x_out

MaxText/tests/llama_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def test_rope(self):
109109
key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)
110110

111111
# Compare results
112-
self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
113-
self.assertTrue(jax.numpy.allclose(llama_output[1], key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
112+
self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
113+
self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
114114

115115
if __name__ == '__main__':
116116
unittest.main()

0 commit comments

Comments
 (0)