From 42a8cd1fe2bc0f5b61b211354686a27349a5887a Mon Sep 17 00:00:00 2001
From: rasbt
Date: Tue, 14 Jan 2025 10:59:32 -0600
Subject: [PATCH] fix rope tests

---
 tests/test_rope.py | 140 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 110 insertions(+), 30 deletions(-)

diff --git a/tests/test_rope.py b/tests/test_rope.py
index 0aa10aeb58..bd28d1f012 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -1,11 +1,12 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
+from dataclasses import dataclass
+
 import torch
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
 from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
-from transformers.models.llama.configuration_llama import LlamaConfig
 
 from litgpt.model import apply_rope, build_rope_cache
 
@@ -17,7 +18,23 @@ def test_rope_gptneox():
     x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
     position_ids = torch.arange(seq_len).unsqueeze(0)
 
-    theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+
+    config = RoPEConfig(
+        dim=head_size,
+        max_position_embeddings=seq_len,
+        rope_theta=10_000,
+        hidden_size=head_size * n_head,
+        num_attention_heads=n_head
+    )
+
+    theirs_rot_emb = GPTNeoXRotaryEmbedding(config)
     theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)
 
     ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
@@ -35,13 +52,32 @@ def test_rope_gptneox():
 def test_rope_llama_2():
     head_dim = 64
     rope_theta = 10_000
+    num_heads = 4
+    batch_size, seq_len = 1, 10
 
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
-    batch_size, seq_len = 1, 10
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config)
+
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -56,8 +92,6 @@ def test_rope_llama_2():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -76,13 +110,33 @@ def test_rope_llama_3():
     head_dim = 64
     rope_theta = 50_000
+    num_heads = 4
+    batch_size, seq_len = 1, 10
 
     ##################################
     # Compare cos and sin
     ##################################
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        scaling_factor: float
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        scaling_factor=None
+    )
+
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
-    batch_size, seq_len = 1, 10
+    rot_emb = LlamaRotaryEmbedding(config)
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -97,8 +151,6 @@ def test_rope_llama_3():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -117,6 +169,8 @@ def test_rope_llama_3_1():
     head_dim = 32
     rope_theta = 50_000
+    num_heads = 4
+    batch_size, seq_len = 1, 131_072
 
     their_rope_config = {
         "factor": 8.0,
         "low_freq_factor": 1.0,
         "high_freq_factor": 4.0,
         "original_max_position_embeddings": 8192,
         "rope_type": "llama3"
     }
 
     ours_rope_config = {
         "factor": 8.0,
         "low_freq_factor": 1.0,
         "high_freq_factor": 4.0,
         "original_max_seq_len": 8192
     }
 
-    config = LlamaConfig(
-        rope_theta=rope_theta,
-        rope_scaling=their_rope_config,
-        head_dim=head_dim
-    )
-
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 131_072
+
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        rope_type: str
+        rope_scaling: dict
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        rope_type="llama3",
+        rope_scaling=their_rope_config
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config=config)
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -159,8 +227,6 @@ def test_rope_llama_3_1():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
@@ -179,6 +245,8 @@ def test_rope_llama_3_1():
 def test_rope_llama_3_2():
     head_dim = 32
     rope_theta = 50_000
+    batch_size, seq_len = 1, 131_072
+    num_heads = 4
 
     their_rope_config = {
         "factor": 32.0,
         "low_freq_factor": 1.0,
         "high_freq_factor": 4.0,
         "original_max_position_embeddings": 8192,
         "rope_type": "llama3"
     }
 
     ours_rope_config = {
         "factor": 32.0,
         "low_freq_factor": 1.0,
         "high_freq_factor": 4.0,
         "original_max_seq_len": 8192
     }
 
-    config = LlamaConfig(
-        rope_theta=rope_theta,
-        rope_scaling=their_rope_config,
-        head_dim=head_dim
-    )
-
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
-    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 131_072
+    @dataclass
+    class RoPEConfig:
+        dim: int
+        max_position_embeddings: int
+        rope_theta: int
+        hidden_size: int
+        num_attention_heads: int
+        rope_type: str
+        rope_scaling: dict
+
+    config = RoPEConfig(
+        dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=rope_theta,
+        hidden_size=head_dim * num_heads,
+        num_attention_heads=num_heads,
+        rope_type="llama3",
+        rope_scaling=their_rope_config
+    )
+
+    rot_emb = LlamaRotaryEmbedding(config)
+
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -221,8 +303,6 @@ def test_rope_llama_3_2():
     ##################################
     # Compare rotated tensors
     ##################################
-    # Settings
-    num_heads = 4
 
     # Dummy query and key tensors
     torch.manual_seed(123)
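
Note (not part of the patch itself): this patch adapts the tests to newer transformers releases in which GPTNeoXRotaryEmbedding and LlamaRotaryEmbedding are constructed from a single config object rather than from positional head-dim/base arguments. The rotary init reads fields such as rope_theta, hidden_size, num_attention_heads, and optionally rope_scaling/rope_type off that object, with the head dimension falling back to hidden_size // num_attention_heads when no head_dim attribute is present, which is why the tests drop the LlamaConfig import and instead pass a small duck-typed dataclass exposing exactly those fields. Below is a minimal standalone sketch of the same pattern, assuming a transformers version with the config-based constructor (roughly 4.48 or newer); the field list mirrors the patch, while the variable names and the shape comment are illustrative only.

from dataclasses import dataclass

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


@dataclass
class RoPEConfig:
    # Only the attributes the rotary-embedding code is expected to read.
    dim: int
    max_position_embeddings: int
    rope_theta: int
    hidden_size: int
    num_attention_heads: int


head_dim, num_heads, seq_len = 64, 4, 10
config = RoPEConfig(
    dim=head_dim,
    max_position_embeddings=seq_len,
    rope_theta=10_000,
    # head dim is recovered as hidden_size // num_attention_heads
    hidden_size=head_dim * num_heads,
    num_attention_heads=num_heads,
)

rot_emb = LlamaRotaryEmbedding(config)
dummy = torch.randn(1, seq_len, head_dim)          # only dtype/device of x matter to the forward pass
position_ids = torch.arange(seq_len).unsqueeze(0)  # shape (1, seq_len)
cos, sin = rot_emb(dummy, position_ids)            # each roughly (1, seq_len, head_dim)

From there, the resulting cos/sin can be compared against litgpt's build_rope_cache(seq_len, head_dim) output, which is what the updated tests in the hunks above do.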