diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index 0784e06..c803109 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -8,9 +8,11 @@
 #
 
 import logging
+import math
 from typing import Optional
 
 import einops
+import torch
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
@@ -76,10 +78,11 @@ def __init__(
         self.dropout_p = dropout_p
         self.is_causal = is_causal
         self.softcap = softcap
-        self.alibi_slopes = alibi_slopes
+        self.use_alibi_slopes = True  # use_alibi_slopes
 
-        if alibi_slopes is not None:
-            assert alibi_slopes.shape[0] == num_heads
+        if self.use_alibi_slopes is not None:
+            self.alibi_slopes = get_alibi_slopes(num_heads)
+            assert self.alibi_slopes.shape[0] == num_heads
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
 
@@ -126,6 +129,9 @@ def forward(
             query, key, value = (
                 einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
             )
+
+            alibi_slopes = self.alibi_slopes.repeat(query.shape[0]).to(query.device) if self.use_alibi_slopes else None
+
             out = self.attention(
                 query,
                 key,
@@ -134,7 +140,7 @@ def forward(
                 window_size=self.window_size,
                 dropout_p=dropout_p,
                 softcap=self.softcap,
-                alibi_slopes=self.alibi_slopes,
+                alibi_slopes=alibi_slopes,
             )
             out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         else:
@@ -152,3 +158,14 @@ def forward(
         out = self.projection(out)
 
         return out
+
+
+def get_alibi_slopes(num_heads: int) -> Tensor:
+    n = 2 ** math.floor(math.log2(num_heads))
+    slope_0 = 2.0 ** (-8.0 / n)
+    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
+    if n < num_heads:
+        slope_hat_0 = 2.0 ** (-4.0 / n)
+        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
+    return alibi_slopes
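
For context, the added get_alibi_slopes helper follows the ALiBi slope schedule (Press et al., "Train Short, Test Long"): with n = 2**floor(log2(num_heads)), the first n heads get the geometric sequence (2**(-8/n))**1, ..., (2**(-8/n))**n, and any remaining heads get slopes interpolated from a second sequence based on 2**(-4/n). A minimal sanity-check sketch, assuming the patched package is installed so the module touched by this diff is importable as anemoi.models.layers.attention:

# Illustrative only: prints the per-head slopes produced by the new helper.
from anemoi.models.layers.attention import get_alibi_slopes

# Power-of-two head count: a plain geometric sequence 1/2, 1/4, ..., 1/256.
print(get_alibi_slopes(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

# Non-power-of-two head count: 4 "base" slopes (ratio 1/4) plus 2 interpolated
# slopes (1/2 and 1/8) for the remaining heads, 6 values in total.
print(get_alibi_slopes(6))
# tensor([0.2500, 0.0625, 0.0156, 0.0039, 0.5000, 0.1250])

# The constructor's assert holds: one slope per attention head.
assert get_alibi_slopes(16).shape[0] == 16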