Commit f04ba76
Copybara import of the project:
--
190a1a73753a665087be14f6f82823d24608cce6 by Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>:

added support for GQA

--
9022b351c2d0e26fc1abbad228379b383cd3fa8a by Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>:

added GQA support for cudnn flash attention

COPYBARA_INTEGRATE_REVIEW=#555 from kocchop:cudnn_flash_dpa 9022b351c2d0e26fc1abbad228379b383cd3fa8a
PiperOrigin-RevId: 622915708
kocchop authored and maxtext authors committed Apr 8, 2024
1 parent 7482d6a commit f04ba76
Showing 1 changed file with 32 additions and 51 deletions.
MaxText/layers/attentions.py: 32 additions & 51 deletions
@@ -182,7 +182,7 @@ def apply_attention(self,
      if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
        raise ValueError("""Decode not supported with flash attention.
                             Use `dot_product` instead.""")
-      return self.cudnn_flash_attention(query, key, value), None, None
+      return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
    else:
      raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.')

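An aside on the call-site change above: forwarding `decoder_segment_ids` and `model_mode` lets the cuDNN path build an explicit attention mask (see `generate_attention_mask` in the second hunk) instead of assuming dense causal attention over a single unpacked sequence. The sketch below only illustrates that idea under assumed conventions (hypothetical helper name, `decoder_segment_ids` of shape `[batch, seq_len]`, `True` meaning "may attend"); it is not MaxText's actual mask code.

```python
import jax.numpy as jnp

def sketch_causal_segment_mask(decoder_segment_ids):
  """Illustrative sketch: causal mask restricted to packed segments."""
  seq_len = decoder_segment_ids.shape[-1]
  # [batch, seq, seq]: query and key positions carry the same segment id.
  same_segment = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :]
  # Lower-triangular causal constraint, shared across the batch.
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
  # Add a broadcast head axis -> [batch, 1, seq, seq]; mask polarity differs across libraries.
  return (same_segment & causal)[:, None, :, :]
```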
@@ -254,59 +254,40 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
    x = wrap_flash_attention(query, key, value, decoder_segment_ids)
    x = jnp.transpose(x, axes=(0, 2, 1, 3))
    return x

-  def cudnn_flash_attention(self,
-                            query: Array,
-                            key: Array,
-                            value: Array) -> Array:
+  def cudnn_flash_attention(
+    self,
+    query: Array,
+    key: Array,
+    value: Array,
+    decoder_segment_ids: Array | None,
+    model_mode: str = common_types.MODEL_MODE_TRAIN,
+  ) -> Array:
"""CUDNN Flash Attention with Transformer Engine.
It is an unstable API. In future release, the API can get changed
A stable flash attention API will be included soon. Currently,
1. It does not support GQA, num_query_heads == num_kv_heads
2. It supports head_dim till 128
GQA support with head_dim=256 will be added soon
1. Stable API, supports GQA
2. Supports head_dim till 128; head_dim=256 support will be added soon
"""

    # These imports are only meant to work in a GPU build.
-    import transformer_engine.jax.fused_attn as fused_attn # pytype: disable=import-error
-    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout # pytype: disable=import-error
-    from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available # pytype: disable=import-error
-
-    batch, s_q, n_heads, head_dim = query.shape # pylint: disable=unused-variable
-    _, s_kv, _, _ = key.shape
-
-    qkv_layout = QKVLayout.BS3HD
-    attn_mask_type = AttnMaskType.CAUSAL_MASK
-    attn_bias_type = AttnBiasType.NO_BIAS
-
-    has_fused_attn_kernel = is_fused_attn_kernel_available(
-        self.dtype, self.dtype, qkv_layout,
-        attn_bias_type,
-        attn_mask_type,
-        self.dropout_rate, self.num_query_heads,
-        self.num_kv_heads, s_q,
-        s_kv, head_dim)
-
-    if not has_fused_attn_kernel:
-      raise ValueError("Flash attention is not supported for current config i.e. head_dim, seq_len, n_heads etc."
-                       "Please see transformer_engine/common/fused_attn/fused_attn.cpp:NVTE_Fused_Attn_Backend for details")
-
-    q = jnp.reshape(query, (*query.shape[:2], 1, *query.shape[-2:]))
-    k = jnp.reshape(key, (*query.shape[:2], 1, *query.shape[-2:]))
-    v = jnp.reshape(value, (*query.shape[:2], 1, *query.shape[-2:]))
-    qkv = jnp.concatenate((q, k, v), axis=2) # to make it (b, s, 3, h, d)
-
-    return fused_attn.self_fused_attn(
-        qkv=qkv,
-        bias=None,
-        mask=jnp.zeros((batch, 1, s_q, s_kv)), # no padding
-        seed=None,
-        attn_bias_type=attn_bias_type,
-        attn_mask_type=attn_mask_type,
-        scaling_factor=1.0/math.sqrt(head_dim),
-        dropout_probability=self.dropout_rate,
-        is_training=True)
+    from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
+
+    _, _, _, head_dim = query.shape # pylint: disable=unused-variable
+
+    #generate attn_mask
+    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
+
+    dpa_layer = DotProductAttention(head_dim=head_dim,
+                  num_attention_heads=self.num_query_heads,
+                  num_gqa_groups=self.num_kv_heads,
+                  attn_mask_type='causal', # 'causal' or 'padding'
+                  attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+                  attention_dropout=self.dropout_rate,
+                  dropout_rng_name='aqt',
+                  dtype=self.dtype,
+                  float32_logits=self.float32_logits,
+                  qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+                  scale_factor=1.0/math.sqrt(head_dim),
+                  transpose_batch_sequence=False)
+    return dpa_layer(query, key, value, mask=attn_mask)

  def compute_local_attention(self,
                              attn_weights: Array,

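For background on the feature itself: in grouped-query attention (GQA) the query keeps `num_query_heads` heads while key and value carry a smaller `num_kv_heads`, and each KV head is shared by a group of query heads; that is what `num_gqa_groups=self.num_kv_heads` communicates to Transformer Engine's `DotProductAttention`. A minimal unfused JAX sketch of the same head-grouping semantics (no masking or dropout, hypothetical function name and shapes) might look like:

```python
import math
import jax
import jax.numpy as jnp

def gqa_reference(query, key, value):
  """Unfused grouped-query attention sketch.

  query:      [batch, seq, num_query_heads, head_dim]
  key, value: [batch, seq, num_kv_heads, head_dim], with
              num_query_heads % num_kv_heads == 0
  """
  _, _, n_q, d = query.shape
  n_kv = key.shape[2]
  group = n_q // n_kv
  # Repeat each KV head so every query head in its group sees the same K/V.
  key = jnp.repeat(key, group, axis=2)      # -> [b, s, n_q, d]
  value = jnp.repeat(value, group, axis=2)  # -> [b, s, n_q, d]
  logits = jnp.einsum('bqhd,bkhd->bhqk', query, key) / math.sqrt(d)
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('bhqk,bkhd->bqhd', probs, value)

# Hypothetical GQA shapes: 32 query heads sharing 8 KV heads (head_dim=128).
q = jnp.zeros((2, 16, 32, 128))
k = jnp.zeros((2, 16, 8, 128))
v = jnp.zeros((2, 16, 8, 128))
out = gqa_reference(q, k, v)  # [2, 16, 32, 128]
```

The fused cuDNN kernel never materializes the repeated K/V or the full logits matrix; the sketch only pins down what "GQA support" means for the head counts.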