feat: make alibi slopes configurable
theissenhelen committed Sep 27, 2024
1 parent 4e2b874 commit 20e5df2
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions src/anemoi/models/layers/attention.py
@@ -22,7 +22,7 @@


class MultiHeadSelfAttention(nn.Module):
-    """Multi Head Self Attention Pytorch Layer."""
+    """Multi Head Self Attention Pytorch Layer using flash attention, see https://github.com/Dao-AILab/flash-attention"""

    def __init__(
        self,
@@ -33,7 +33,32 @@ def __init__(
        window_size: Optional[int] = None,
        dropout_p: float = 0.0,
        softcap: float = 0.0,
+        alibi_slopes: Tensor = None,
    ):
+        """Initialize MultiHeadSelfAttention.
+
+        Parameters
+        ----------
+        num_heads : int
+            number of heads
+        embed_dim : int
+            embedding dimension
+        bias : bool, optional
+            bias, by default False
+        is_causal : bool, optional
+            apply causal attention mask, by default False
+        window_size : Optional[int], optional
+            window_size, by default None
+        dropout_p : float, optional
+            dropout probability, by default 0.0
+        softcap : float, optional
+            Anything > 0 activates softcapping attention, by default 0.0
+        alibi_slopes : Tensor, optional
+            (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j,
+            by default None
+        """
        super().__init__()

        assert (
@@ -51,6 +76,10 @@
        self.dropout_p = dropout_p
        self.is_causal = is_causal
        self.softcap = softcap
+        self.alibi_slopes = alibi_slopes
+
+        if alibi_slopes is not None:
+            assert alibi_slopes.shape[0] == num_heads

        self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

@@ -59,6 +88,7 @@ def __init__(
    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:
+
        query, key, value = self.lin_qkv(x).chunk(3, -1)

        if model_comm_group:
@@ -86,7 +116,14 @@ def forward(
        )

        out = self.attention(
-            query, key, value, dropout_p=dropout_p, causal=False, window_size=self.window_size, softcap=self.softcap
+            query,
+            key,
+            value,
+            dropout_p=dropout_p,
+            causal=False,
+            window_size=self.window_size,
+            softcap=self.softcap,
+            alibi_slopes=self.alibi_slopes,
        )
        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")

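The layer stores whatever slope tensor it is given, so building the slopes is left to the caller. A minimal sketch of the usual geometric recipe from the ALiBi paper, assuming a power-of-two head count; the helper get_alibi_slopes below is illustrative and not part of this commit:

import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Head k (0-indexed) gets slope 2 ** (-8 * (k + 1) / num_heads): earlier heads get
    # larger slopes, i.e. a stronger penalty on distant positions through the bias
    # -slope * |i + seqlen_k - seqlen_q - j| described in the docstring above.
    # Returned as fp32 with shape (num_heads,), as the docstring specifies.
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * exponents / num_heads)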

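A usage sketch wiring such a slope tensor into the layer changed above. The dimensions are arbitrary, only keyword names from the __init__ signature in the diff are used, and the flash-attn backend is assumed to be installed:

num_heads, embed_dim = 16, 1024
attn = MultiHeadSelfAttention(
    num_heads=num_heads,
    embed_dim=embed_dim,
    alibi_slopes=get_alibi_slopes(num_heads),  # fp32, one slope per head, so shape[0] == num_heads holds
)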