feat: get alibi_slopes
theissenhelen committed Oct 2, 2024
1 parent 0eb5c50 commit c04e641
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions src/anemoi/models/layers/attention.py
@@ -8,9 +8,11 @@
 #
 
 import logging
+import math
 from typing import Optional
 
 import einops
+import torch
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
@@ -76,10 +78,11 @@ def __init__(
         self.dropout_p = dropout_p
         self.is_causal = is_causal
         self.softcap = softcap
-        self.alibi_slopes = alibi_slopes
+        self.use_alibi_slopes = True  # use_alibi_slopes
 
-        if alibi_slopes is not None:
-            assert alibi_slopes.shape[0] == num_heads
+        if self.use_alibi_slopes is not None:
+            self.alibi_slopes = get_alibi_slopes(num_heads)
+            assert self.alibi_slopes.shape[0] == num_heads
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
 
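Aside: `use_alibi_slopes` is now a plain bool, so the `is not None` guard in this hunk is always true. A minimal sketch of the guard as a truthiness check, which expresses the same intent (names as in the diff; this is an editorial variant, not what the commit does):

    if self.use_alibi_slopes:
        self.alibi_slopes = get_alibi_slopes(num_heads)
        assert self.alibi_slopes.shape[0] == num_heads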
@@ -126,6 +129,9 @@ def forward(
             query, key, value = (
                 einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
             )
+
+            alibi_slopes = self.alibi_slopes.repeat(query.shape[0]).to(query.device) if self.use_alibi_slopes else None
+
             out = self.attention(
                 query,
                 key,
@@ -134,7 +140,7 @@ def forward(
                 window_size=self.window_size,
                 dropout_p=dropout_p,
                 softcap=self.softcap,
-                alibi_slopes=self.alibi_slopes,
+                alibi_slopes=alibi_slopes,
             )
             out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         else:
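If the backend here is flash-attn's `flash_attn_func`, its documentation specifies `alibi_slopes` as an fp32 tensor of shape `(nheads,)` or `(batch_size, nheads)`. The 1-D `.repeat(query.shape[0])` added above produces a `(batch * nheads,)` tensor instead, which only matches that contract when the batch size is 1. A hedged sketch of a per-batch expansion that keeps the 2-D shape (assuming the flash-attn contract; variable names follow the diff):

    batch_size = query.shape[0]  # after the rearrange, dim 0 is batch
    alibi_slopes = (
        self.alibi_slopes.to(device=query.device, dtype=torch.float32).repeat(batch_size, 1)
        if self.use_alibi_slopes
        else None
    )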
@@ -152,3 +158,14 @@ def forward(
         out = self.projection(out)
 
         return out
+
+
+def get_alibi_slopes(num_heads: int) -> Tensor:
+    n = 2 ** math.floor(math.log2(num_heads))
+    slope_0 = 2.0 ** (-8.0 / n)
+    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
+    if n < num_heads:
+        slope_hat_0 = 2.0 ** (-4.0 / n)
+        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
+    return alibi_slopes

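The new helper mirrors the slope schedule from the ALiBi paper (Press et al., 2021): for the largest power of two n <= num_heads, the slopes are 2^(-8i/n) for i = 1..n, and any remaining heads take every other slope of the doubled (2n) schedule. A short usage sketch with the values this yields:

    # Power-of-two head count: slopes 2^-1 ... 2^-8.
    print(get_alibi_slopes(8))
    # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

    # Non-power-of-two: the 8 slopes above, then 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5
    # (odd-indexed entries of the 16-head schedule) for the 4 extra heads.
    print(get_alibi_slopes(12)[8:])
    # tensor([0.7071, 0.3536, 0.1768, 0.0884])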