OLMo 2 #1897
```diff
@@ -271,12 +271,16 @@ def __init__(
                 " (non-parallel residual and shared attention norm)."
             )

-        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
         self.attn = CausalSelfAttention(config, block_idx)
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
-        self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_2 = (
+            nn.Identity()
+            if not config.norm_2
+            else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
+        )
         self.mlp = config.mlp_class(config)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
```

Review discussion on the `norm_2` change:

- Maybe a less special-casey way of doing this could be to avoid introducing the boolean `norm_1` and `norm_2` configs, and instead just have `Identity` as the norm class itself.
- The issue is that OLMo 2 selectively uses RMSNorm for `post_attention_norm` and `post_mlp_norm` but Identity for `norm_1` and `norm_2`. Perhaps a way to get rid of the booleans would be to specify it as a special case for olmo2. IMO that's the easiest workaround to getting rid of the `norm_1` and `norm_2` booleans.
- How about `norm_1_class` and `norm_2_class` as overrides to `norm_class` in the config file?
- That's a good idea, and it will be advantageous in the future. Wdyt @ysjprojects?
- We can also do it in a follow-up PR; it doesn't need to be this one.
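For illustration only, here is a rough sketch of the `norm_1_class` / `norm_2_class` override idea floated in the thread. None of this is in the PR: the config fields and the helper below are hypothetical, and litgpt's real `Config` differs in detail.

```python
# Hypothetical sketch of the reviewer-suggested per-norm class overrides (not in litgpt).
from dataclasses import dataclass
from typing import Optional, Type

import torch.nn as nn


@dataclass
class ConfigSketch:
    n_embd: int = 4096
    norm_eps: float = 1e-5
    norm_class: Type[nn.Module] = nn.LayerNorm      # litgpt would typically use an RMSNorm here
    norm_1_class: Optional[Type[nn.Module]] = None  # None -> fall back to norm_class
    norm_2_class: Optional[Type[nn.Module]] = None


def build_norm(config: ConfigSketch, override: Optional[Type[nn.Module]]) -> nn.Module:
    cls = override or config.norm_class
    # nn.Identity accepts and ignores constructor arguments, so it slots in without special-casing.
    return cls(config.n_embd, eps=config.norm_eps)


# An OLMo 2 style config would then drop the norm_1 / norm_2 booleans entirely:
olmo2_like = ConfigSketch(norm_1_class=nn.Identity, norm_2_class=nn.Identity)
norm_1 = build_norm(olmo2_like, olmo2_like.norm_1_class)  # nn.Identity()
norm_2 = build_norm(olmo2_like, olmo2_like.norm_2_class)  # nn.Identity()
```

The `post_attention_norm` / `post_mlp_norm` booleans from the hunk above could stay as they are; only the pre-norm slots would change.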
```diff
@@ -325,6 +329,7 @@ def forward(
         else:
             x = attention_output + x
             x_normed = self.norm_2(x)

         return self.post_mlp_norm(self.mlp(x_normed)) + x
```
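For context on why `norm_1` and `norm_2` become `nn.Identity` for OLMo 2: with the pre-norms disabled and the post-attention/post-MLP norms enabled, the non-parallel residual path above collapses to a post-norm block. The following is a minimal, self-contained sketch of that ordering, not litgpt's `Block`; it assumes the post-attention norm wraps the attention output, and uses `nn.MultiheadAttention` plus PyTorch 2.4+'s `nn.RMSNorm` purely for illustration.

```python
import torch
import torch.nn as nn


class OLMo2StyleBlock(nn.Module):
    """Illustrative post-norm block: norms applied to sub-layer outputs, not inputs."""

    def __init__(self, n_embd: int, n_head: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.post_attention_norm = nn.RMSNorm(n_embd, eps=eps)
        self.mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.SiLU(), nn.Linear(4 * n_embd, n_embd))
        self.post_mlp_norm = nn.RMSNorm(n_embd, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # norm_1 is Identity here, so attention sees the raw residual stream (no causal mask, for brevity).
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.post_attention_norm(attn_out) + x
        # norm_2 is Identity as well, so the MLP also sees the raw residual stream.
        return self.post_mlp_norm(self.mlp(x)) + x


x = torch.randn(2, 8, 64)
print(OLMo2StyleBlock(64, 4)(x).shape)  # torch.Size([2, 8, 64])
```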
```diff
@@ -346,8 +351,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]

         if config.norm_qk:
-            self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps)
-            self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps)
+            norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            norm_k_size = (
+                config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            )
+            self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)
+            self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)
         else:
             self.norm_q = self.norm_k = None
```
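The size arithmetic is the substantive difference between the two `norm_qk_type` values: `"olmo2"` builds norms over the full q/k projections (all heads at once, `n_head * head_size` and `n_query_groups * head_size`), while `"default"` keeps per-head norms of size `head_size`. A small standalone shape check, with made-up dimensions and PyTorch's `nn.RMSNorm` standing in for the configured `norm_class`:

```python
import torch
import torch.nn as nn

B, T = 2, 8
n_head, n_query_groups, head_size = 8, 4, 16

q = torch.randn(B, T, n_head * head_size)          # (B, T, n_head * head_size), before the head reshape
k = torch.randn(B, T, n_query_groups * head_size)  # (B, T, n_query_groups * head_size)

# norm_qk_type="olmo2": normalize the full projections before splitting heads.
olmo2_norm_q = nn.RMSNorm(n_head * head_size)
olmo2_norm_k = nn.RMSNorm(n_query_groups * head_size)
q_olmo2 = olmo2_norm_q(q)
k_olmo2 = olmo2_norm_k(k)

# norm_qk_type="default": normalize per head, after the (B, nh, T, hs) reshape and transpose.
default_norm = nn.RMSNorm(head_size)
q_default = default_norm(q.view(B, T, n_head, head_size).transpose(1, 2))
k_default = default_norm(k.view(B, T, n_query_groups, head_size).transpose(1, 2))

print(q_olmo2.shape, q_default.shape)  # torch.Size([2, 8, 128]) torch.Size([2, 8, 8, 16])
```

Note that the two are not equivalent rearrangements: RMSNorm over the full projection computes its statistics across all heads jointly, whereas the per-head variant normalizes each head independently.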
```diff
@@ -387,6 +396,10 @@ def forward(
         # Split qkv into query, key and value matrices.
         q, k, v = qkv.split((query_size, key_size, value_size), dim=-1)  # 3x(B, T, C*)

+        if self.config.norm_qk and self.config.norm_qk_type == "olmo2":
+            q = self.norm_q(q)
+            k = self.norm_k(k)
+
         # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
         # embedding size (C) into num_heads (nh) and head_size (hs).
         q = q.view(B, T, n_head, head_size)  # (B, T, nh_q, hs)
```
```diff
@@ -400,7 +413,7 @@ def forward(
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)

-        if self.config.norm_qk:
+        if self.config.norm_qk and self.config.norm_qk_type == "default":
             q = self.norm_q(q)
             k = self.norm_k(k)
```
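Tying the two hunks together: the `"olmo2"` path normalizes q and k right after the qkv split, while they are still in `(B, T, n_head * head_size)` layout, and the `"default"` path keeps normalizing after the per-head reshape and transpose. A config that opts into the new behaviour would, per this diff, look roughly like the sketch below; the field names come from the diff, but the grouping is illustrative and the real OLMo 2 entries live in the config file, which is not shown here.

```python
# Illustrative override dict; field names come from this diff, values are an assumption.
olmo2_style_overrides = dict(
    norm_qk=True,              # build norm_q / norm_k at all
    norm_qk_type="olmo2",      # normalize the full q/k projections before the head reshape
    norm_1=False,              # Block.norm_1 -> nn.Identity (no pre-attention norm)
    norm_2=False,              # Block.norm_2 -> nn.Identity (no pre-MLP norm)
    post_attention_norm=True,  # norm on the attention output
    post_mlp_norm=True,        # norm on the MLP output
)
```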