diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index 025eb2d..1d594ae 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -22,7 +22,7 @@ class MultiHeadSelfAttention(nn.Module):
-    """Multi Head Self Attention Pytorch Layer."""
+    """Multi Head Self Attention Pytorch Layer using flash attention, see https://github.com/Dao-AILab/flash-attention"""
 
     def __init__(
         self,
@@ -33,7 +33,32 @@ def __init__(
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
         softcap: float = 0.0,
+        alibi_slopes: Tensor = None,
     ):
+        """Initialize MultiHeadSelfAttention.
+
+        Parameters
+        ----------
+        num_heads : int
+            number of heads
+        embed_dim : int
+            embedding dimension
+        bias : bool, optional
+            bias, by default False
+        is_causal : bool, optional
+            apply causal attention mask, by default False
+        window_size : Optional[int], optional
+            window_size, by default None
+        dropout_p : float, optional
+            dropout probability, by default 0.0
+        softcap : float, optional
+            Anything > 0 activates softcapping attention, by default 0.0
+        alibi_slopes : Tensor, optional
+            (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j,
+            by default None
+        """
         super().__init__()
 
         assert (
@@ -51,6 +76,10 @@ def __init__(
         self.dropout_p = dropout_p
         self.is_causal = is_causal
         self.softcap = softcap
+        self.alibi_slopes = alibi_slopes
+
+        if alibi_slopes is not None:
+            assert alibi_slopes.shape[0] == num_heads
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
 
@@ -59,6 +88,7 @@ def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
+
         query, key, value = self.lin_qkv(x).chunk(3, -1)
 
         if model_comm_group:
@@ -86,7 +116,14 @@ def forward(
         )
 
         out = self.attention(
-            query, key, value, dropout_p=dropout_p, causal=False, window_size=self.window_size, softcap=self.softcap
+            query,
+            key,
+            value,
+            dropout_p=dropout_p,
+            causal=False,
+            window_size=self.window_size,
+            softcap=self.softcap,
+            alibi_slopes=self.alibi_slopes,
         )
 
         out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
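
The sketch below is not part of the diff; it illustrates how the new alibi_slopes argument could be constructed and passed in. get_alibi_slopes is a hypothetical helper implementing the geometric slopes from the ALiBi paper (https://arxiv.org/abs/2108.12409) for a power-of-two number of heads, the layer dimensions are illustrative, and running it assumes the flash-attn package used by this layer is installed.

import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention


def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Hypothetical helper: geometric ALiBi slopes 2^(-8*i/num_heads) for i = 1..num_heads.
    # This sketch only covers the power-of-two head counts described in the ALiBi paper.
    assert (num_heads & (num_heads - 1)) == 0, "sketch assumes a power-of-two number of heads"
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * exponents / num_heads)


# One fp32 slope per head, matching the new assert in __init__ (alibi_slopes.shape[0] == num_heads);
# the tensor may still need to be moved to the model's device before the forward pass.
slopes = get_alibi_slopes(16)

layer = MultiHeadSelfAttention(
    num_heads=16,
    embed_dim=1024,
    window_size=512,
    dropout_p=0.0,
    softcap=0.0,
    alibi_slopes=slopes,
)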