fix rope tests
rasbt committed Jan 14, 2025
1 parent a3ecd0b commit 42a8cd1
Showing 1 changed file with 110 additions and 30 deletions.
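
Note: the test update switches from positional-argument construction of the Hugging Face rotary embeddings (e.g. GPTNeoXRotaryEmbedding(head_size, seq_len)) to passing a small config-like object carrying dim, max_position_embeddings, rope_theta, hidden_size, and num_attention_heads, as the diff below shows. Independently of that API change, both sides of each comparison should produce the standard RoPE cos/sin cache. A minimal sketch of that computation, not litgpt's or transformers' exact code (shapes and helper names here are illustrative only):

import torch

def rope_cache_sketch(seq_len, head_dim, theta=10_000.0):
    # One inverse frequency per pair of dimensions: theta^(-2i / head_dim)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Angle for every (position, frequency) pair
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    # GPT-NeoX / Llama convention: repeat the angles to cover the full head_dim
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()   # each of shape (seq_len, head_dim)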
tests/test_rope.py — 140 changes: 110 additions & 30 deletions
@@ -1,11 +1,12 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from dataclasses import dataclass

import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
from transformers.models.llama.configuration_llama import LlamaConfig

from litgpt.model import apply_rope, build_rope_cache

@@ -17,7 +18,23 @@ def test_rope_gptneox():
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
position_ids = torch.arange(seq_len).unsqueeze(0)

theirs_rot_emb = GPTNeoXRotaryEmbedding(head_size, seq_len)
@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int

config = RoPEConfig(
dim=head_size,
max_position_embeddings=seq_len,
rope_theta=10_000,
hidden_size=head_size * n_head,
num_attention_heads=n_head
)

theirs_rot_emb = GPTNeoXRotaryEmbedding(config)
theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)

ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)
@@ -35,13 +52,32 @@ def test_rope_gptneox():
def test_rope_llama_2():
head_dim = 64
rope_theta = 10_000
num_heads = 4
batch_size, seq_len = 1, 10

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads
)

rot_emb = LlamaRotaryEmbedding(config)

qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -56,8 +92,6 @@ def test_rope_llama_2():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
@@ -76,13 +110,33 @@ def test_rope_llama_2():
def test_rope_llama_3():
head_dim = 64
rope_theta = 50_000
num_heads = 4
batch_size, seq_len = 1, 10

##################################
# Compare cos and sin
##################################

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
scaling_factor: float

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
scaling_factor=None
)

# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
rot_emb = LlamaRotaryEmbedding(config)
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -97,8 +151,6 @@ def test_rope_llama_3():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
@@ -117,6 +169,8 @@ def test_rope_llama_3():
def test_rope_llama_3_1():
head_dim = 32
rope_theta = 50_000
num_heads = 4
batch_size, seq_len = 1, 131_072

their_rope_config = {
"factor": 8.0,
@@ -133,18 +187,32 @@ def test_rope_llama_3_1():
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config,
head_dim=head_dim
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 131_072

@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
rope_type: str
rope_scaling: dict

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
rope_type="llama3",
rope_scaling=their_rope_config
)

rot_emb = LlamaRotaryEmbedding(config=config)
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -159,8 +227,6 @@ def test_rope_llama_3_1():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
@@ -179,6 +245,8 @@ def test_rope_llama_3_2():
def test_rope_llama_3_2():
head_dim = 32
rope_theta = 50_000
batch_size, seq_len = 1, 131_072
num_heads = 4

their_rope_config = {
"factor": 32.0,
@@ -195,18 +263,32 @@ def test_rope_llama_3_2():
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config,
head_dim=head_dim
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 131_072
@dataclass
class RoPEConfig:
dim: int
max_position_embeddings: int
rope_theta: int
hidden_size: int
num_attention_heads: int
rope_type: str
rope_scaling: dict

config = RoPEConfig(
dim=head_dim,
max_position_embeddings=seq_len,
rope_theta=rope_theta,
hidden_size=head_dim * num_heads,
num_attention_heads=num_heads,
rope_type="llama3",
rope_scaling=their_rope_config
)

rot_emb = LlamaRotaryEmbedding(config)

qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -221,8 +303,6 @@ def test_rope_llama_3_2():
##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
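For reference, the "Compare rotated tensors" blocks in these tests exercise the rotate-half formulation of RoPE. A minimal sketch, assuming (batch, n_head, seq_len, head_dim) inputs and cos/sin tensors that broadcast over batch and heads; the exact litgpt and transformers signatures differ:

import torch

def rotate_half(x):
    # Split the last dimension in half, swap the halves, negate the second one
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(x, cos, sin):
    # x: (batch, n_head, seq_len, head_dim); cos/sin broadcast over batch and heads
    return x * cos + rotate_half(x) * sin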

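The test_rope_llama_3_1 and test_rope_llama_3_2 cases additionally pass a rope_scaling dict ("factor", "original_max_seq_len", plus low- and high-frequency factors that are collapsed out of this diff view). Below is a hedged sketch of the Llama-3-style inverse-frequency rescaling such a config drives, paraphrasing the commonly used scheme; the default argument values are hypothetical stand-ins for the elided keys, not the values used in the tests:

import math
import torch

def llama3_scaled_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, original_max_seq_len=8192):
    # low_freq_factor / high_freq_factor defaults are hypothetical stand-ins;
    # the actual values live in the part of the rope_scaling dict not shown above.
    low_freq_wavelen = original_max_seq_len / low_freq_factor
    high_freq_wavelen = original_max_seq_len / high_freq_factor
    new_freqs = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:        # short wavelengths: leave untouched
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:       # long wavelengths: slow down by `factor`
            new_freqs.append(freq / factor)
        else:                                  # smooth interpolation in between
            smooth = (original_max_seq_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.stack(new_freqs)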