OLMo 2 #1897

Status: Merged (33 commits, Jun 4, 2025)

Commits (33)
b62eed2  OLMo 2: implemented core (ysjprojects, Jan 4, 2025)
f559763  minor fix (ysjprojects, Jan 4, 2025)
276a8fc  fix vocab size (ysjprojects, Jan 4, 2025)
1ac888f  fix test_model (ysjprojects, Jan 4, 2025)
d3456e3  custom conversion fn for olmo2 due to new q_norm and k_norm components (ysjprojects, Jan 8, 2025)
121f851  minor fix (ysjprojects, Jan 8, 2025)
3d34921  minor fix on test_model.py (ysjprojects, Jan 8, 2025)
15f549d  fix: post_feedforward_layernorm (ysjprojects, Jan 8, 2025)
ac3509f  minor fix (ysjprojects, Jan 8, 2025)
852ca3e  input_norm (ysjprojects, Jan 8, 2025)
ff47a66  Merge branch 'main' into olmo2 (ysjprojects, Feb 26, 2025)
69adbd9  fixed olmo2 (ysjprojects, Feb 26, 2025)
5ab1796  Merge branch 'main' into olmo2 (Borda, Mar 11, 2025)
8da6edb  Merge branch 'main' into olmo2 (Borda, Mar 20, 2025)
ad5724f  CUBLAS_WORKSPACE_CONFIG (Borda, Mar 20, 2025)
9e40a07  Merge branch 'main' into olmo2 (ysjprojects, Apr 1, 2025)
29d2a76  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 1, 2025)
b73fac5  Merge branch 'main' into olmo2 (Borda, Apr 2, 2025)
34d6c15  Merge branch 'main' into olmo2 (Borda, Apr 2, 2025)
5fa318a  Merge branch 'main' into olmo2 (Borda, Apr 2, 2025)
e1caecf  Merge branch 'main' into olmo2 (Borda, Apr 2, 2025)
e8b43c8  Merge branch 'main' into olmo2 (Borda, Apr 3, 2025)
8d8e327  Merge branch 'main' into olmo2 (Borda, Apr 3, 2025)
487187f  Merge branch 'main' into olmo2 (lantiga, Apr 3, 2025)
786650c  removed .gitignore redundant part (ysjprojects, Apr 23, 2025)
c7bcfb2  localize norm_q and norm_k to invocation strictly when norm_qk is True (ysjprojects, Apr 23, 2025)
b1bbe36  revert prev and removed redundant (ysjprojects, Apr 23, 2025)
bc0a34e  Merge branch 'main' into olmo2 (Borda, May 22, 2025)
d4179d1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 22, 2025)
e7a8052  Merge branch 'main' into olmo2 (ysjprojects, Jun 3, 2025)
0b00942  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 3, 2025)
0bbd667  fixes to olmo2 q and k norm modules (Jun 3, 2025)
f627087  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 3, 2025)
65 changes: 65 additions & 0 deletions litgpt/config.py
@@ -38,6 +38,7 @@ class Config:
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
norm_qk: bool = False
norm_qk_type: Literal["default", "olmo2"] = "default"
post_attention_norm: bool = False
post_mlp_norm: bool = False
parallel_residual: bool = True
@@ -91,6 +92,8 @@ class Config:
scale_embeddings: bool = False
lm_head_bias: bool = False
final_logit_softcapping: Optional[float] = None
norm_1: bool = True
norm_2: bool = True
# The base period of the RoPE embeddings for local attention.
# If not provided, rope_theta will be used for both local and global attention.
rope_local_base_freq: Optional[float] = None
@@ -930,6 +933,68 @@ def norm_class(self) -> Type:

configs.extend(olmo)

olmo2 = [
# https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
dict(
name="OLMo-2-1124-7B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=4096,
n_layer=32,
n_head=32,
n_query_groups=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=11008,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
norm_1=False,
norm_2=False,
norm_qk_type="olmo2",
post_attention_norm=True,
),
# https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
dict(
name="OLMo-2-1124-13B{}",
hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
vocab_size=100278,
padded_vocab_size=100352,
block_size=4096,
n_embd=5120,
n_layer=40,
n_head=40,
n_query_groups=40,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
norm_eps=1e-06,
intermediate_size=13824,
rope_base=500000,
norm_qk=True,
post_mlp_norm=True,
norm_1=False,
norm_2=False,
norm_qk_type="olmo2",
post_attention_norm=True,
),
]

for c in olmo2:
for kind in ("", "-SFT", "-DPO", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)

###############
# Google Gemma
###############
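
For reference, the expansion loop above registers eight OLMo 2 names in total: the two base checkpoints plus -SFT, -DPO and -Instruct variants of each. A minimal sketch of the same pattern, with placeholder dicts standing in for the real Config entries:

```python
from copy import deepcopy

base_configs = [
    {"name": "OLMo-2-1124-7B{}", "hf_config": {"org": "allenai", "name": "OLMo-2-1124-7B{}"}},
    {"name": "OLMo-2-1124-13B{}", "hf_config": {"org": "allenai", "name": "OLMo-2-1124-13B{}"}},
]

configs = []
for c in base_configs:
    for kind in ("", "-SFT", "-DPO", "-Instruct"):
        copy = deepcopy(c)  # avoid mutating the shared template
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)

print([c["name"] for c in configs])
# ['OLMo-2-1124-7B', 'OLMo-2-1124-7B-SFT', 'OLMo-2-1124-7B-DPO', 'OLMo-2-1124-7B-Instruct',
#  'OLMo-2-1124-13B', 'OLMo-2-1124-13B-SFT', 'OLMo-2-1124-13B-DPO', 'OLMo-2-1124-13B-Instruct']
```
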
23 changes: 18 additions & 5 deletions litgpt/model.py
@@ -271,12 +271,16 @@ def __init__(
" (non-parallel residual and shared attention norm)."
)

self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config, block_idx)
self.post_attention_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
)
self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
self.norm_2 = (
nn.Identity()
if not config.norm_2
else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
)
self.mlp = config.mlp_class(config)
self.post_mlp_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()

Review discussion on the norm_1/norm_2 booleans:

Contributor: Maybe a less special-casey way of doing this could be to avoid introducing the boolean norm_1 and norm_2 configs, and instead just use Identity as the norm class itself.

Collaborator (author, ysjprojects): The issue is that OLMo 2 selectively uses RMSNorm for post_attention_norm and post_mlp_norm but Identity for norm_1 and norm_2. Perhaps a way to get rid of the booleans would be to special-case olmo2, e.g. `self.norm_1 = nn.Identity() if config.name.lower().startswith(("olmo-2-")) else ...`. IMO that's the easiest workaround to getting rid of the norm_1 and norm_2 booleans.

Collaborator: How about norm_1_class and norm_2_class as overrides to norm_class in the config file? Then reading the config could set up norm_1_class and norm_2_class either from the config names or from norm_class. This would move the special cases from the model to the config. Ideally, we'd also subsume shared_attention_norm, which seems to be very similar to `not norm_2`.

Contributor: That's a good idea and it will be advantageous in the future. wdyt @ysjprojects?

Contributor: We can also do it in a follow-up PR; it doesn't need to be this one.
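
For illustration, a rough sketch of the norm_1_class / norm_2_class override idea floated in the thread above. This is hypothetical and not what the PR implements; NormConfigSketch and make_norm are invented names, and nn.RMSNorm needs a recent PyTorch (>= 2.4):

```python
from dataclasses import dataclass
from typing import Optional

from torch import nn


@dataclass
class NormConfigSketch:
    """Hypothetical per-position norm overrides instead of norm_1/norm_2 booleans."""

    n_embd: int = 4096
    norm_eps: float = 1e-6
    norm_class_name: str = "RMSNorm"         # global default
    norm_1_class_name: Optional[str] = None  # None -> fall back to norm_class_name
    norm_2_class_name: Optional[str] = None

    def make_norm(self, override: Optional[str] = None) -> nn.Module:
        name = override or self.norm_class_name
        if name == "Identity":
            return nn.Identity()
        cls = {"LayerNorm": nn.LayerNorm, "RMSNorm": nn.RMSNorm}[name]
        return cls(self.n_embd, eps=self.norm_eps)


# An OLMo 2 entry would set the per-position overrides to "Identity" and keep RMSNorm elsewhere.
cfg = NormConfigSketch(norm_1_class_name="Identity", norm_2_class_name="Identity")
norm_1 = cfg.make_norm(cfg.norm_1_class_name)  # nn.Identity()
post_attention_norm = cfg.make_norm()          # nn.RMSNorm(4096)
```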
@@ -325,6 +329,7 @@ def forward(
else:
x = attention_output + x
x_normed = self.norm_2(x)

return self.post_mlp_norm(self.mlp(x_normed)) + x
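
With the OLMo 2 settings added above (norm_1=False, norm_2=False, post_attention_norm=True, post_mlp_norm=True, parallel_residual=False), the forward pass reduces to a "norm-after" residual layout: the sublayer outputs are normalized rather than their inputs. A minimal sketch of that effective computation, using torch.nn.RMSNorm (PyTorch >= 2.4) and plain nn.Linear stand-ins for the real attention and MLP modules; illustrative only, not litgpt code:

```python
import torch
from torch import nn


class Olmo2StyleBlock(nn.Module):
    """Effective OLMo 2 block once norm_1 and norm_2 are Identity (sketch)."""

    def __init__(self, dim: int, attn: nn.Module, mlp: nn.Module, eps: float = 1e-6) -> None:
        super().__init__()
        self.attn = attn  # stand-in for CausalSelfAttention
        self.mlp = mlp    # stand-in for LLaMAMLP
        self.post_attention_norm = nn.RMSNorm(dim, eps=eps)
        self.post_mlp_norm = nn.RMSNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norms are Identity, so each sublayer's output is normalized instead.
        x = x + self.post_attention_norm(self.attn(x))
        return x + self.post_mlp_norm(self.mlp(x))


block = Olmo2StyleBlock(dim=8, attn=nn.Linear(8, 8), mlp=nn.Linear(8, 8))
out = block(torch.randn(2, 4, 8))  # (batch, seq, dim)
```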


@@ -346,8 +351,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]

if config.norm_qk:
self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps)
self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps)
norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
norm_k_size = (
config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
)
self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)
self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)
else:
self.norm_q = self.norm_k = None

@@ -387,6 +396,10 @@ def forward(
# Split qkv into query, key and value matrices.
q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*)

if self.config.norm_qk and self.config.norm_qk_type == "olmo2":
q = self.norm_q(q)
k = self.norm_k(k)

# To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
# embedding size (C) into num_heads (nh) and head_size (hs).
q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs)
@@ -400,7 +413,7 @@ def forward(
k = k.transpose(1, 2) # (B, nh_k, T, hs)
v = v.transpose(1, 2) # (B, nh_v, T, hs)

if self.config.norm_qk:
if self.config.norm_qk and self.config.norm_qk_type == "default":
q = self.norm_q(q)
k = self.norm_k(k)
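
The two branches differ in where the normalization is applied: with norm_qk_type="olmo2" the norms act on the flat projections straight out of the qkv split (last dimension n_head*head_size for q, n_query_groups*head_size for k), whereas the default path normalizes each head over head_size after the reshape and transpose. A shape-only sketch under OLMo-2-7B-like dimensions (illustrative; random tensors in place of real activations):

```python
import torch
from torch import nn

B, T, n_head, n_query_groups, head_size = 2, 5, 32, 32, 128  # OLMo-2-7B-like numbers

q_flat = torch.randn(B, T, n_head * head_size)           # straight out of the qkv split
k_flat = torch.randn(B, T, n_query_groups * head_size)

# norm_qk_type="olmo2": normalize the flat projection, then split into heads.
norm_q = nn.RMSNorm(n_head * head_size, eps=1e-6)         # PyTorch >= 2.4
norm_k = nn.RMSNorm(n_query_groups * head_size, eps=1e-6)
q = norm_q(q_flat).view(B, T, n_head, head_size).transpose(1, 2)           # (B, nh_q, T, hs)
k = norm_k(k_flat).view(B, T, n_query_groups, head_size).transpose(1, 2)   # (B, nh_k, T, hs)

# norm_qk_type="default": reshape first, then normalize each head over head_size.
norm_q_default = nn.RMSNorm(head_size, eps=1e-6)
q_default = norm_q_default(q_flat.view(B, T, n_head, head_size).transpose(1, 2))
```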

2 changes: 2 additions & 0 deletions litgpt/prompts.py
@@ -472,6 +472,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Llama3()
if re.search("Llama-3.*-Instruct-*", model_name):
return Llama3()
if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
return Llama3()
if re.search("R1", model_name):
return R1Base()
if re.search("FreeWilly2", model_name):
83 changes: 83 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -533,6 +533,85 @@ def copy_weights_qwen_2_5(
pbar.update(progress_per_file)


def copy_weights_olmo2(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
pbar: Optional[tqdm] = None,
progress_per_file: Optional[float] = None,
debug_mode: Optional[bool] = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
"model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias",
"model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
"model.norm.weight": "transformer.ln_f.weight",
"model.norm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
}
)
else:
raise NotImplementedError

if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if progress_per_file is not None:
pbar.update(progress_per_file)

if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# qkv is split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
qkv = torch.cat((q, k, v))
state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
del qkv_weights[i][weight_type]

if progress_per_file is not None:
pbar.update(progress_per_file)
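
The per-layer qkv bookkeeping above exists because the HF checkpoint stores q_proj, k_proj and v_proj separately while litgpt expects one fused qkv matrix per layer. A standalone sketch of the fusion step with OLMo-2-7B-like shapes (n_embd=4096, 32 heads of size 128, n_query_groups equal to n_head), using random tensors in place of real weights:

```python
import torch

n_embd, n_head, head_size = 4096, 32, 128
q = torch.randn(n_head * head_size, n_embd)  # model.layers.{i}.self_attn.q_proj.weight
k = torch.randn(n_head * head_size, n_embd)  # k_proj (n_query_groups == n_head here)
v = torch.randn(n_head * head_size, n_embd)  # v_proj

qkv = torch.cat((q, k, v))                   # -> transformer.h.{i}.attn.qkv.weight
print(qkv.shape)                             # torch.Size([12288, 4096])
```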


def copy_weights_qwen_3(
config: Config,
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
@@ -693,6 +772,10 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
elif model_name.lower().startswith("olmo-2-"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
elif model_name.lower().startswith("qwen3"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
60 changes: 60 additions & 0 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -393,6 +393,64 @@ def copy_weights_qwen_2_5(
state_dict[to_name] = param


def copy_weights_olmo2(
config: Config,
state_dict: Dict[str, torch.Tensor],
lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
untie_weights: bool = False,
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
"transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
"transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
"transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias",
"transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
"transformer.ln_f.weight": "model.norm.weight",
"transformer.ln_f.bias": "model.norm.bias",
"lm_head.weight": "lm_head.weight",
}
if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
weight_map.update(
{
"transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
"transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
}
)
else:
raise NotImplementedError

for from_name, param in lit_weights.items():
if from_name == "lm_head.weight" and untie_weights:
continue
name_template, *ids = layer_template(from_name, num_matches=2)
param = load_param(param, from_name, None)
if from_name.endswith(".attn.qkv.weight"):
to_names = (
"model.layers.{}.self_attn.q_proj.weight".format(*ids),
"model.layers.{}.self_attn.k_proj.weight".format(*ids),
"model.layers.{}.self_attn.v_proj.weight".format(*ids),
)
params = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
else:
to_names = (weight_map[name_template].format(*ids),)
params = (param,)

for to_name, param in zip(to_names, params):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
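
The reverse conversion splits the fused litgpt matrix back into HF's three projections. With the 7B geometry (head_size = 4096 / 32 = 128 and n_query_groups = n_head = 32), the split sizes work out to (4096, 4096, 4096) along dim 0. A small sketch with random tensors:

```python
import torch

n_head, n_query_groups, head_size, n_embd = 32, 32, 128, 4096
qkv = torch.randn((n_head + 2 * n_query_groups) * head_size, n_embd)  # fused litgpt weight

q, k, v = qkv.split(
    (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)
)
print(q.shape, k.shape, v.shape)  # each torch.Size([4096, 4096])
```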


def copy_weights_qwen_3(
config: Config,
state_dict: Dict[str, torch.Tensor],
@@ -487,6 +545,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
copy_fn = partial(copy_weights_phi, config)
elif config.name.lower().startswith(("qwen2.5", "qwq")):
copy_fn = partial(copy_weights_qwen_2_5, config)
elif config.name.lower().startswith("olmo-2-"):
copy_fn = partial(copy_weights_olmo2, config)
elif config.name.lower().startswith("qwen3"):
copy_fn = partial(copy_weights_qwen_3, config)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):