diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 76f7026a3..30d9ad315 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -83,7 +83,7 @@ jobs:
       - name: Test llama2_decode.py
         run: |
           docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
-          'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/llama2-7b/decode-ckpt-maxtext load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share."'
+          'bash end_to_end/llama_load_and_test.sh'
 # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job
   gpu:
diff --git a/MaxText/convert_llama_ckpt.py b/MaxText/convert_llama_ckpt.py
index d177ad725..d06c26eba 100644
--- a/MaxText/convert_llama_ckpt.py
+++ b/MaxText/convert_llama_ckpt.py
@@ -43,6 +43,11 @@ jax.config.update('jax_platform_name', 'cpu')
 
+def permute_to_match_maxtext_rope(arr):
+  evens = arr[..., ::2]
+  odds = arr[..., 1::2]
+  return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1)
+
 MODEL_PARAMS_DICT = {
     '70b': {
         'num_layers': 80,
@@ -172,6 +177,8 @@ def convert(base_model_path, maxtext_model_path, model_size):
       wq = np.reshape(wq, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])
       wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
       wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
+      wq = permute_to_match_maxtext_rope(wq)
+      wk = permute_to_match_maxtext_rope(wk)
 
       w_post = np.concatenate(
           [
@@ -223,10 +230,6 @@ def convert(base_model_path, maxtext_model_path, model_size):
 
     jax_weights['decoder']['decoder']['self_attention'] = self_attention
 
-
-
-
-
     layer_weight['mlp']['wi_0']['kernel'] = np.array(layer_weight['mlp']['wi_0']['kernel'])
     layer_weight['mlp']['wi_1']['kernel'] = np.array(layer_weight['mlp']['wi_1']['kernel'])
     layer_weight['mlp']['wo']['kernel'] = np.array(layer_weight['mlp']['wo']['kernel'])
@@ -250,8 +253,6 @@ def convert(base_model_path, maxtext_model_path, model_size):
   #convert all weights to jax.numpy
   jax_weights = jax.tree_map(jnp.array, jax_weights)
 
-  print(f"jax_weights = {jax_weights}")
-
   #dummy configs for the checkpoint_manager
   step_number_to_save_new_ckpt = 0
   enable_checkpointing=True
@@ -272,10 +273,7 @@ def convert(base_model_path, maxtext_model_path, model_size):
                               params=jax_weights,
                               tx=None, # type: ignore
                               opt_state={}
-)
-
-
-  print(f"Trainstate after replacing params with jax_weights={state_new}")
+                              )
 
   if checkpoint_manager is not None:
     if checkpoint_manager.save(step_number_to_save_new_ckpt, state_new):
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 326314f17..d6094cf3c 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -41,7 +41,7 @@ PRNGKey = common_types.PRNGKey
 
 DenseGeneral = linears.DenseGeneral
-LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
+RotaryEmbedding = embeddings.RotaryEmbedding
 NdInitializer = initializers.NdInitializer
 
 AxisNames = common_types.AxisNames
@@ -621,7 +621,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array:
 
   def key_rotary(self, key: Array, inputs_positions: Array):
     """Apply Rotary Embedding to key."""
-    key = LLaMARotaryEmbedding(
+    key = RotaryEmbedding(
         embedding_dims=self.head_dim,
         name='key_rotary')(inputs=key, position=inputs_positions)
     return key
@@ -663,7 +663,7 @@ def __call__(self,
     value = self.kv_projection(inputs_kv, proj_name='value')
 
     # apply ROPE
-    query = LLaMARotaryEmbedding(
+    query = RotaryEmbedding(
         embedding_dims=self.head_dim, name='query_rotary'
     )(inputs=query, position=inputs_positions)
     key = self.key_rotary(key, inputs_positions)
diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index 48c9a579a..6fa1fafb4 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -102,8 +102,8 @@ def attend(self, query: Array) -> Array:
     return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
 
 
-class LLaMARotaryEmbedding(nn.Module):
-  """LLaMA variant of ROPE where inputs are split in a different way.
+class RotaryEmbedding(nn.Module):
+  """RoPE
 
   Attributes:
     min_timescale: Start of the geometric index. Determines the periodicity of
@@ -167,8 +167,9 @@ def __call__(
     sin = jnp.sin(sinusoid_inp)
     cos = jnp.cos(sinusoid_inp)
     reshape_tensor = inputs.astype(jnp.float32).reshape(
-        *inputs.shape[:-1], -1, 2
+        *inputs.shape[:-1], 2, -1
     )
+    reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2)
     first_half = reshape_tensor[..., 0]
     second_half = reshape_tensor[..., 1]
     first_part = first_half * cos - second_half * sin
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index cf5c101f4..bd01fe0bc 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -43,9 +43,7 @@ ScanIn = common_types.ScanIn
 
 Embed = embeddings.Embed
-LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
 Attention = attentions.Attention
-AttentionOp = attentions.AttentionOp
 RMSNorm = normalizations.RMSNorm
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 52bde475c..698f499d3 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -35,7 +35,6 @@ ScanIn = common_types.ScanIn
 
 Embed = embeddings.Embed
-LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
 Attention = attentions.Attention
 RMSNorm = normalizations.RMSNorm
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index f5eba46a8..d85676c3d 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -33,7 +33,6 @@ Mesh = jax.sharding.Mesh
 
 Attention = attentions.Attention
-LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
 
 
 class AttentionTest(unittest.TestCase):
diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py
index 2c613103b..db7bca697 100644
--- a/MaxText/tests/llama_test.py
+++ b/MaxText/tests/llama_test.py
@@ -76,9 +76,14 @@ def apply_rotary_emb(
   return xq_out.astype(dtype), xk_out.astype(dtype)
 
 
-class LlamaRoPETest(unittest.TestCase):
+def permute_to_match_maxtext_rope(arr):
+  evens = arr[..., ::2]
+  odds = arr[..., 1::2]
+  return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1)
+
+class RoPETest(unittest.TestCase):
   """Test for the RoPE implementation """
 
-  def test_llama_rope(self):
+  def test_rope(self):
     dim_per_head = 128
     seq_len = 8
@@ -100,8 +105,8 @@ def test_llama_rope(self):
     position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
 
     # Calculate RoPE embeddings from MaxText implementation
-    query_proj = embeddings.LLaMARotaryEmbedding(embedding_dims = dim_per_head)(x_q, position = position)
-    key_proj = embeddings.LLaMARotaryEmbedding(embedding_dims = dim_per_head)(x_k, position = position)
+    query_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_q), position = position)
+    key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)
 
     # Compare results
     self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
diff --git a/end_to_end/llama_load_and_test.sh b/end_to_end/llama_load_and_test.sh
new file mode 100644
index 000000000..6377d8e5f
--- /dev/null
+++ b/end_to_end/llama_load_and_test.sh
@@ -0,0 +1,6 @@
+set -e
+idx=$(date +%Y-%m-%d-%H-%M)
+pip install torch
+gsutil cp -r gs://maxtext-llama/llama2-7b/meta-ckpt /tmp/
+python3 MaxText/convert_llama_ckpt.py --base-model-path /tmp/meta-ckpt --model-size 7b --maxtext-model-path gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/
+python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_${idx} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/ load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product
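
For reference, a minimal standalone sketch (Python/JAX; the head_dim value is illustrative and not taken from the diff) of the layout change that motivates permute_to_match_maxtext_rope: the Meta-style reference (apply_rotary_emb in llama_test.py) rotates interleaved feature pairs (2i, 2i+1), while the renamed RotaryEmbedding now reshapes the feature dimension as (..., 2, -1) and therefore pairs feature i with feature i + head_dim/2. Permuting the q/k projection weights at conversion time (and the test inputs) bridges the two conventions.

# Standalone sketch; mirrors the helper added in convert_llama_ckpt.py and llama_test.py.
import jax.numpy as jnp

def permute_to_match_maxtext_rope(arr):
  # Gather even-indexed features, then odd-indexed features, along the last axis,
  # mapping the interleaved (2i, 2i+1) RoPE pairing onto the (i, i + d/2)
  # split-halves pairing used by MaxText's RotaryEmbedding.
  evens = arr[..., ::2]
  odds = arr[..., 1::2]
  return jnp.concatenate((evens, odds), axis=arr.ndim - 1)

head_dim = 8  # illustrative only; the diff's llama_test.py uses dim_per_head = 128
x = jnp.arange(head_dim)
print(permute_to_match_maxtext_rope(x))  # [0 2 4 6 1 3 5 7]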