diff --git a/litgpt/config.py b/litgpt/config.py
index 133a9247a1..a4ef0bc21f 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -912,6 +912,58 @@ def norm_class(self) -> Type:
 configs.extend(olmo)
 
 
+olmo2 = [
+    # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-7B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
+        vocab_size=100352,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=4096,
+        n_layer=32,
+        n_head=32,
+        n_query_groups=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=11008,
+        rope_base=500000,
+        norm_qk=True,
+    ),
+    # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-13B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
+        vocab_size=100352,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=5120,
+        n_layer=40,
+        n_head=40,
+        n_query_groups=40,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=13824,
+        rope_base=500000,
+        norm_qk=True,
+    ),
+]
+
+for c in olmo2:
+    for kind in ("", "-SFT", "-DPO", "-Instruct"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+
 ###############
 # Google Gemma
 ###############
diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 48850efd51..d815712cb4 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -368,6 +368,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Llama3()
     if re.search("Llama-3.*-Instruct-*", model_name):
         return Llama3()
+    if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
+        return Llama3()
     if re.search("FreeWilly2", model_name):
         return FreeWilly2()
     if re.search("Platypus", model_name):
diff --git a/tests/test_model.py b/tests/test_model.py
index e8a110a409..2ddaec0576 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -28,6 +28,7 @@
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
+from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 
 import litgpt.config as config_module
@@ -617,6 +618,64 @@ def test_against_olmo(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference upcasts the softmax to fp32 during attention;
+                # additionally, the final layernorm input is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_olmo2(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = Olmo2Config(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        intermediate_size=ours_config.intermediate_size,
+        num_hidden_layers=ours_config.n_layer,
+        num_attention_heads=ours_config.n_head,
+        num_key_value_heads=ours_config.n_query_groups,
+        max_position_embeddings=T,
+        attention_bias=ours_config.bias,
+        rope_theta=ours_config.rope_base,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = Olmo2ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     ("device", "dtype"),
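
Reviewer note (not part of the diff): a minimal sketch of how the registration loop and the new prompt-style rule are expected to behave, assuming the Config.from_name and model_name_to_prompt_style APIs touched in this change:

    from litgpt.config import Config
    from litgpt.prompts import model_name_to_prompt_style

    # The loop in litgpt/config.py registers the base checkpoint plus the
    # -SFT, -DPO, and -Instruct variants, so all four names should resolve;
    # only the suffixed variants match the new OLMo-2 prompt-style rule.
    for kind in ("", "-SFT", "-DPO", "-Instruct"):
        name = f"OLMo-2-1124-7B{kind}"
        cfg = Config.from_name(name)
        style = model_name_to_prompt_style(name)
        print(name, cfg.n_layer, cfg.n_embd, type(style).__name__)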