OLMo 2: implemented core
ysjprojects committed Jan 4, 2025
1 parent 40c08dc commit b62eed2
Showing 3 changed files with 113 additions and 0 deletions.
52 changes: 52 additions & 0 deletions litgpt/config.py
@@ -912,6 +912,58 @@ def norm_class(self) -> Type:

configs.extend(olmo)

olmo2 = [
    # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
    dict(
        name="OLMo-2-1124-7B{}",
        hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
        vocab_size=100352,
        padded_vocab_size=100352,
        block_size=4096,
        n_embd=4096,
        n_layer=32,
        n_head=32,
        n_query_groups=32,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        norm_eps=1e-06,
        intermediate_size=11008,
        rope_base=500000,
        norm_qk=True,
    ),
    # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
    dict(
        name="OLMo-2-1124-13B{}",
        hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
        vocab_size=100352,
        padded_vocab_size=100352,
        block_size=4096,
        n_embd=5120,
        n_layer=40,
        n_head=40,
        n_query_groups=40,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        norm_eps=1e-06,
        intermediate_size=13824,
        rope_base=500000,
        norm_qk=True,
    ),
]

for c in olmo2:
    for kind in ("", "-SFT", "-DPO", "-Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)

###############
# Google Gemma
###############
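For context, the loop above registers four name variants per base entry. A minimal sketch of what that makes available, assuming the registered configs are resolved through `Config.from_name` (as the new test below does) and that `Config` is importable from `litgpt.config`:

from litgpt.config import Config

# Each base name is expanded with "", "-SFT", "-DPO", and "-Instruct", so all four
# variants resolve to the same architecture hyperparameters.
for name in ("OLMo-2-1124-7B", "OLMo-2-1124-7B-SFT", "OLMo-2-1124-7B-DPO", "OLMo-2-1124-7B-Instruct"):
    cfg = Config.from_name(name)
    print(name, cfg.n_layer, cfg.n_embd)  # expected: 32 layers, 4096 hidden size for the 7B entries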
2 changes: 2 additions & 0 deletions litgpt/prompts.py
@@ -368,6 +368,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
        return Llama3()
    if re.search("Llama-3.*-Instruct-*", model_name):
        return Llama3()
    if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
        return Llama3()
    if re.search("FreeWilly2", model_name):
        return FreeWilly2()
    if re.search("Platypus", model_name):
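A small sketch of the effect of the new prompt-style rule, assuming `model_name_to_prompt_style` and `Llama3` are importable from `litgpt.prompts` (the hunk above shows the function but not the module's exports):

from litgpt.prompts import Llama3, model_name_to_prompt_style

# Post-trained OLMo 2 checkpoints reuse the Llama 3 chat template.
style = model_name_to_prompt_style("OLMo-2-1124-13B-Instruct")
assert isinstance(style, Llama3)  # "-SFT" and "-DPO" variants match the same pattern

# The base pretrained checkpoint has no -Instruct/-SFT/-DPO suffix, so it does not
# match this rule and keeps whatever default prompt style litgpt assigns.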
59 changes: 59 additions & 0 deletions tests/test_model.py
@@ -28,6 +28,7 @@
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

import litgpt.config as config_module
@@ -617,6 +618,64 @@ def test_against_olmo(model_name, device, dtype):
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B"))
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_olmo2(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    ours_config = Config.from_name(
        model_name,
        padded_vocab_size=10000,
        n_layer=2,
        n_head=8,
        n_embd=32,
        intermediate_size=86,
    )
    T = 5
    theirs_config = Olmo2Config(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        intermediate_size=ours_config.intermediate_size,
        num_hidden_layers=ours_config.n_layer,
        num_attention_heads=ours_config.n_head,
        num_key_value_heads=ours_config.n_query_groups,
        max_position_embeddings=T,
        attention_bias=ours_config.bias,
        rope_theta=ours_config.rope_base,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = Olmo2ForCausalLM(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
("device", "dtype"),
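To exercise only the new parity test, something like the following should work (a sketch using pytest's standard programmatic entry point; the exact invocation is not part of this commit):

import pytest

# Select the OLMo 2 parity test by keyword. The CUDA/float16 case is gated by
# RunIf(min_cuda_gpus=1) and marked xfail, so on a CPU-only machine only the
# CPU/float32 case runs.
exit_code = pytest.main(["tests/test_model.py", "-k", "test_against_olmo2", "-q"])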
