Merge pull request #326 from google:rwitten_rope_swap
PiperOrigin-RevId: 597677549
maxtext authors committed Jan 12, 2024
2 parents 2116d35 + 18e96c7 commit b947cdd
Showing 9 changed files with 31 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -83,7 +83,7 @@ jobs:
- name: Test llama2_decode.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
- 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/llama2-7b/decode-ckpt-maxtext load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share."'
+ 'bash end_to_end/llama_load_and_test.sh'
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'tpu' job
gpu:
18 changes: 8 additions & 10 deletions MaxText/convert_llama_ckpt.py
@@ -43,6 +43,11 @@

jax.config.update('jax_platform_name', 'cpu')

+ def permute_to_match_maxtext_rope(arr):
+ evens = arr[..., ::2]
+ odds = arr[..., 1::2]
+ return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1)
+
MODEL_PARAMS_DICT = {
'70b': {
'num_layers': 80,
@@ -172,6 +177,8 @@ def convert(base_model_path, maxtext_model_path, model_size):
wq = np.reshape(wq, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])
wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim])
+ wq = permute_to_match_maxtext_rope(wq)
+ wk = permute_to_match_maxtext_rope(wk)

w_post = np.concatenate(
[
@@ -223,10 +230,6 @@ def convert(base_model_path, maxtext_model_path, model_size):

jax_weights['decoder']['decoder']['self_attention'] = self_attention

-
-
-
-
layer_weight['mlp']['wi_0']['kernel'] = np.array(layer_weight['mlp']['wi_0']['kernel'])
layer_weight['mlp']['wi_1']['kernel'] = np.array(layer_weight['mlp']['wi_1']['kernel'])
layer_weight['mlp']['wo']['kernel'] = np.array(layer_weight['mlp']['wo']['kernel'])
@@ -250,8 +253,6 @@ def convert(base_model_path, maxtext_model_path, model_size):
#convert all weights to jax.numpy
jax_weights = jax.tree_map(jnp.array, jax_weights)

print(f"jax_weights = {jax_weights}")

#dummy configs for the checkpoint_manager
step_number_to_save_new_ckpt = 0
enable_checkpointing=True
@@ -272,10 +273,7 @@ def convert(base_model_path, maxtext_model_path, model_size):
params=jax_weights,
tx=None, # type: ignore
opt_state={}
- )
-
-
- print(f"Trainstate after replacing params with jax_weights={state_new}")
+ )

if checkpoint_manager is not None:
if checkpoint_manager.save(step_number_to_save_new_ckpt, state_new):
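For reference, the permute_to_match_maxtext_rope helper added in this file moves even-indexed features into the first half of the last axis and odd-indexed features into the second half, converting the Meta checkpoint's interleaved RoPE layout into the split-halves layout that the renamed RotaryEmbedding expects; the converter applies it to the reshaped wq and wk. A minimal, self-contained sketch of the same permutation (the import, toy array, and shape below are illustrative and not part of the diff):

import jax.numpy as jnp

def permute_to_match_maxtext_rope(arr):
  # Even-indexed features first, then odd-indexed features, along the last axis.
  evens = arr[..., ::2]
  odds = arr[..., 1::2]
  return jnp.concatenate((evens, odds), axis=arr.ndim - 1)

# Toy check: last axis [0 1 2 3 4 5] -> [0 2 4 1 3 5]
x = jnp.arange(6).reshape(1, 6)
print(permute_to_match_maxtext_rope(x))  # [[0 2 4 1 3 5]]
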
6 changes: 3 additions & 3 deletions MaxText/layers/attentions.py
@@ -41,7 +41,7 @@
PRNGKey = common_types.PRNGKey

DenseGeneral = linears.DenseGeneral
- LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
+ RotaryEmbedding = embeddings.RotaryEmbedding
NdInitializer = initializers.NdInitializer

AxisNames = common_types.AxisNames
@@ -621,7 +621,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array:

def key_rotary(self, key: Array, inputs_positions: Array):
"""Apply Rotary Embedding to key."""
- key = LLaMARotaryEmbedding(
+ key = RotaryEmbedding(
embedding_dims=self.head_dim,
name='key_rotary')(inputs=key, position=inputs_positions)
return key
@@ -663,7 +663,7 @@ def __call__(self,
value = self.kv_projection(inputs_kv, proj_name='value')

# apply ROPE
- query = LLaMARotaryEmbedding(
+ query = RotaryEmbedding(
embedding_dims=self.head_dim, name='query_rotary'
)(inputs=query, position=inputs_positions)
key = self.key_rotary(key, inputs_positions)
7 changes: 4 additions & 3 deletions MaxText/layers/embeddings.py
@@ -102,8 +102,8 @@ def attend(self, query: Array) -> Array:
return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)


- class LLaMARotaryEmbedding(nn.Module):
- """LLaMA variant of ROPE where inputs are split in a different way.
+ class RotaryEmbedding(nn.Module):
+ """RoPE
Attributes:
min_timescale: Start of the geometric index. Determines the periodicity of
@@ -167,8 +167,9 @@ def __call__(
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
reshape_tensor = inputs.astype(jnp.float32).reshape(
- *inputs.shape[:-1], -1, 2
+ *inputs.shape[:-1], 2, -1
)
+ reshape_tensor = jax.numpy.swapaxes(reshape_tensor, -1, -2)
first_half = reshape_tensor[..., 0]
second_half = reshape_tensor[..., 1]
first_part = first_half * cos - second_half * sin
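For reference, the reshape change in RotaryEmbedding above alters which features are paired for the sin/cos rotation: the old reshape to (..., -1, 2) paired adjacent features (0,1), (2,3), ..., while the new reshape to (..., 2, -1) followed by a swap of the last two axes pairs feature i with feature i + dim/2. A small jax.numpy sketch of the two pairings on a toy 8-dimensional feature vector (the variable names and the omission of batch/sequence axes are illustrative only):

import jax.numpy as jnp

x = jnp.arange(8.0)  # stand-in for the last (feature) axis of `inputs`

# Old pairing: adjacent features rotate together -> (0,1), (2,3), ...
old = x.reshape(-1, 2)
old_first, old_second = old[..., 0], old[..., 1]   # [0. 2. 4. 6.] and [1. 3. 5. 7.]

# New pairing: feature i rotates with feature i + dim/2 -> (0,4), (1,5), ...
new = jnp.swapaxes(x.reshape(2, -1), -1, -2)
new_first, new_second = new[..., 0], new[..., 1]   # [0. 1. 2. 3.] and [4. 5. 6. 7.]

This split-halves pairing appears to be why convert_llama_ckpt.py permutes the query/key projection weights and llama_test.py permutes its test inputs with permute_to_match_maxtext_rope before this embedding is applied.
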
2 changes: 0 additions & 2 deletions MaxText/layers/llama2.py
@@ -43,9 +43,7 @@
ScanIn = common_types.ScanIn

Embed = embeddings.Embed
- LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
Attention = attentions.Attention
- AttentionOp = attentions.AttentionOp
RMSNorm = normalizations.RMSNorm


1 change: 0 additions & 1 deletion MaxText/layers/models.py
@@ -35,7 +35,6 @@
ScanIn = common_types.ScanIn

Embed = embeddings.Embed
- LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding
Attention = attentions.Attention
RMSNorm = normalizations.RMSNorm

1 change: 0 additions & 1 deletion MaxText/tests/attention_test.py
@@ -33,7 +33,6 @@

Mesh = jax.sharding.Mesh
Attention = attentions.Attention
- LLaMARotaryEmbedding = embeddings.LLaMARotaryEmbedding


class AttentionTest(unittest.TestCase):
13 changes: 9 additions & 4 deletions MaxText/tests/llama_test.py
@@ -76,9 +76,14 @@ def apply_rotary_emb(

return xq_out.astype(dtype), xk_out.astype(dtype)

- class LlamaRoPETest(unittest.TestCase):
+ def permute_to_match_maxtext_rope(arr):
+ evens = arr[..., ::2]
+ odds = arr[..., 1::2]
+ return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1)
+
+ class RoPETest(unittest.TestCase):
"""Test for the RoPE implementation """
- def test_llama_rope(self):
+ def test_rope(self):
dim_per_head = 128
seq_len = 8

@@ -100,8 +105,8 @@ def test_llama_rope(self):
position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]

# Calculate RoPE embeddings from MaxText implementation
- query_proj = embeddings.LLaMARotaryEmbedding(embedding_dims = dim_per_head)(x_q, position = position)
- key_proj = embeddings.LLaMARotaryEmbedding(embedding_dims = dim_per_head)(x_k, position = position)
+ query_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_q), position = position)
+ key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)

# Compare results
self.assertTrue(jax.numpy.allclose(llama_output[0], query_proj, rtol=1e-01, atol=1e-04, equal_nan=False))
6 changes: 6 additions & 0 deletions end_to_end/llama_load_and_test.sh
@@ -0,0 +1,6 @@
set -e
idx=$(date +%Y-%m-%d-%H-%M)
pip install torch
gsutil cp -r gs://maxtext-llama/llama2-7b/meta-ckpt /tmp/
python3 MaxText/convert_llama_ckpt.py --base-model-path /tmp/meta-ckpt --model-size 7b --maxtext-model-path gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/
python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_${idx} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset load_from_other_directory=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext/ load_from_other_directory_step=0 per_device_batch_size=1 model_name='llama2-7b' assets_path=gs://maxtext-llama/llama2-7b ici_tensor_parallelism=4 steps=1 max_prefill_predict_length=4 max_target_length=16 async_checkpointing=false prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product
