Commit
Merge pull request #1086 from kocchop:faysal/add-swa-to-cudnn-flash-te
PiperOrigin-RevId: 709393275
maxtext authors committed Dec 24, 2024
2 parents 77f6459 + d0f1d38 commit 9afb70f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions MaxText/layers/attentions.py
@@ -393,30 +393,38 @@ def cudnn_flash_attention(
       model_mode: str = common_types.MODEL_MODE_TRAIN,
   ) -> Array:
     """CUDNN Flash Attention with Transformer Engine.
-    1. Stable API, supports GQA
-    2. Supports head_dim till 128; head_dim=256 support will be added soon
+    1. Stable API, supports GQA, SWA (only with causal masking)
+    2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6
     """
     # These imports are only meant to work in a GPU build.
     from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error

     _, _, _, head_dim = query.shape  # pylint: disable=unused-variable

-    # generate attn_mask
-    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
+    sliding_window_size = self.sliding_window_size
+    if self.attention_type == AttentionType.LOCAL_SLIDING:
+      sliding_window_size = [self.sliding_window_size, 0]
+      mask_type = "causal"  # SWA only works with causal masking
+      attn_mask = None
+    else:
+      # generate attn_mask
+      mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+      attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

     dpa_layer = DotProductAttention(
         head_dim=head_dim,
         num_attention_heads=self.num_query_heads,
         num_gqa_groups=self.num_kv_heads,
-        attn_mask_type="padding_causal",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
-        attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+        attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+        attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
         attention_dropout=self.dropout_rate,
         dropout_rng_name="aqt",
         dtype=self.dtype,
         float32_logits=self.float32_logits,
         qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
         scale_factor=1.0 / math.sqrt(head_dim),
         transpose_batch_sequence=False,
+        window_size=sliding_window_size,
     )
     return dpa_layer(query, key, value, mask=attn_mask)
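A minimal sketch (not part of the diff above) of the masking semantics this change expresses: window_size=[self.sliding_window_size, 0] is assumed to follow the usual [left, right] window convention, so each query token can look back at most sliding_window_size positions and never ahead, which is why it is paired with causal masking. The helper name sliding_window_causal_mask below is hypothetical and only illustrates that convention in pure JAX.

    import jax.numpy as jnp

    def sliding_window_causal_mask(seq_len: int, sliding_window_size: int) -> jnp.ndarray:
      """Boolean [seq_len, seq_len] mask; True where a query may attend to a key."""
      q_pos = jnp.arange(seq_len)[:, None]  # query positions (rows)
      k_pos = jnp.arange(seq_len)[None, :]  # key positions (columns)
      causal = k_pos <= q_pos               # never attend to future tokens
      in_window = (q_pos - k_pos) <= sliding_window_size  # bounded lookback
      return causal & in_window

    # Example: with seq_len=6 and sliding_window_size=2, query position 5
    # attends only to key positions 3, 4 and 5.
    print(sliding_window_causal_mask(6, 2))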
