diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 77cd8d235a..189a3d7a6d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -132,705 +132,622 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } -// select a backend for fused attention -NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right) { - using namespace transformer_engine; - NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); - NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - auto cudnn_runtime_version = cudnnGetVersion(); - - // For ragged offsets we only support 32-bit prior to cuDNN 9.5 - // Only used when THD format is requested. - const bool requires_64bit_ragged_offset = - (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( - layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); - const bool supported_ragged_offset_size = - (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == 
NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000)) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { - bool flag_m512 = false; - bool flag_arb = false; - if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && - (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && - ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - max_seqlen_q == max_seqlen_kv) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && - ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && - ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { - flag_m512 = true; - } - if ( - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // architecture - ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || - (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || - (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) && - // sequence length - ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || - (cudnn_runtime_version >= 90000)) && - // number of heads - ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 8907)) && - // head dimension - // multiples of 8 - (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && - // <= 128 - ((head_dim_qk <= 128 && head_dim_v <= 128) || - // 9.1: <= 256 + Hopper + fprop - // 9.5: <= 256 + Hopper + bprop - (head_dim_qk <= 256 && head_dim_v <= 256 && - ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || - (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || - // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 - (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && - layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10.2: any head_dim + any arch + fprop + paged - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91002 && - (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || - (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || - // 9.11: d_qk = 192, d_v = 128 + Blackwell + 
bprop + non-paged - (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && - // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed - (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && - head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && - // bias type - ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (cudnn_runtime_version >= 8906 && - (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - (bias_type == NVTE_Bias_Type::NVTE_ALIBI && - attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - sm_arch_ >= 90) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - (cudnn_runtime_version >= 90000 && - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && - // mask type - // pre-8.9.6: causal - ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} - (cudnn_runtime_version >= 8906 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.1: adds thd + {padding, padding_causal} - (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90300 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} - (cudnn_runtime_version >= 90500 && - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90600 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} - // for any q_format/kv_format, and paged/non-paged - (cudnn_runtime_version >= 90700 && - (attn_mask_type == 
NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv)))) && - // bias + mask combination - (!(cudnn_runtime_version >= 8906 && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - // qkv format - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && - ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || - (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || - (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90700)) && - // sliding window - // pre-9.2: full attn, causal - ((cudnn_runtime_version < 90200 && window_size_left == -1 && - (window_size_right == -1 || window_size_right == 0)) || - // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} - (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q == max_seqlen_kv)) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || - // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} - (cudnn_runtime_version >= 90600 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - // TODO(cyang): fix bug for BRCM + cross-attention on sm100 - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700)))) || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700))))) && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0)))) && - // check 64-bit ragged offset support - (supported_ragged_offset_size) && - // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && - // softmax type - // pre-9.13.1: vanilla - // 9.13.1+: vanilla, off-by-one, learnable - (cudnn_runtime_version >= 91301 || - (cudnn_runtime_version < 91301 && - softmax_type == 
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { - flag_arb = true; - } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); - } - } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; +namespace { +struct BackendSelectionContext { + bool is_training; + NVTEDType q_dtype; + NVTE_QKV_Layout qkv_layout; + NVTE_Bias_Type bias_type; + NVTE_Mask_Type attn_mask_type; + NVTE_Softmax_Type softmax_type; + float dropout; + size_t num_attn_heads; + size_t num_gqa_groups; + size_t max_seqlen_q; + size_t max_seqlen_kv; + size_t head_dim_qk; + size_t head_dim_v; + int64_t window_size_left; + int64_t window_size_right; + + int sm_arch; + int cudnn_version; + NVTE_QKV_Format qkv_format; + NVTE_QKV_Format q_format; + NVTE_QKV_Format kv_format; + NVTE_QKV_Layout_Group layout_group; + bool requires_64bit_ragged_offset; + bool supported_ragged_offset_size; + + std::string error_msg; + + void set_error(const std::string &msg) { error_msg = msg; } +}; + +bool checks_for_fp8(BackendSelectionContext &ctx) { + // Check dtype + if (ctx.q_dtype != NVTEDType::kNVTEFloat8E4M3 && ctx.q_dtype != NVTEDType::kNVTEFloat8E5M2) { + ctx.set_error("FP8 backend requires FP8E4M3 or FP8E5M2 dtype"); + return false; + } + + // Check architecture + if (ctx.sm_arch < 90) { + ctx.set_error("FP8 backend requires SM90 (Hopper) or newer, got SM" + + std::to_string(ctx.sm_arch)); + return false; + } + + // Check bias + if (ctx.bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { + ctx.set_error("FP8 backend requires NVTE_NO_BIAS"); + return false; + } + + bool version_has_support = false; + // cuDNN 8.9: t3hd, max_s=512, d=64, padding + if (ctx.cudnn_version >= 8900 && ctx.sm_arch < 100 && + ctx.qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && ctx.max_seqlen_q == ctx.max_seqlen_kv && + ctx.max_seqlen_q <= 512 && ctx.head_dim_qk == 64 && ctx.head_dim_v == 64 && + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) { + version_has_support = true; + } + // cuDNN 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + if (ctx.cudnn_version >= 90201 && ctx.sm_arch < 100 && ctx.max_seqlen_q % 128 == 0 && + ctx.max_seqlen_kv % 128 == 0 && ctx.head_dim_qk == 128 && ctx.head_dim_v == 128 && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) { + version_has_support = true; + } + // cuDNN 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + if (ctx.cudnn_version >= 90700) { + bool head_dim_ok = false; + if (ctx.sm_arch < 100 && !ctx.is_training && ctx.head_dim_qk <= 256 && 
ctx.head_dim_v <= 256) { + head_dim_ok = true; + } else if (ctx.sm_arch < 100 && ctx.is_training && ctx.head_dim_qk == 128 && + ctx.head_dim_v == 128) { + head_dim_ok = true; + } else if (ctx.sm_arch >= 100 && ctx.head_dim_qk <= 128 && ctx.head_dim_v <= 128) { + head_dim_ok = true; } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; + if (head_dim_ok && ctx.head_dim_qk % 16 == 0 && ctx.head_dim_v % 16 == 0 && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) { + version_has_support = true; } - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - return backend; + if (!version_has_support) { + ctx.set_error("FP8 backend: cuDNN version" + std::to_string(ctx.cudnn_version) + + " does not support provided head_dim, seqlen, or mask"); + return false; + } + + // Check common constraints + if (ctx.qkv_format != NVTE_QKV_Format::NVTE_BSHD && + ctx.qkv_format != NVTE_QKV_Format::NVTE_SBHD) { + ctx.set_error("FP8 backend requires BSHD or SBHD format"); + return false; + } + if (ctx.requires_64bit_ragged_offset) { + ctx.set_error("FP8 backend does not support 64-bit ragged offsets"); + return false; + } + if (ctx.softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) { + ctx.set_error("FP8 backend requires vanilla softmax"); + return false; + } + if (ctx.cudnn_version == 91000) { + ctx.set_error("FP8 backend has known bugs in cuDNN 9.10.0"); + return false; + } + + return true; } -// NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); - using namespace transformer_engine; +bool checks_for_max512(BackendSelectionContext &ctx) { + // Check dtype + if (ctx.q_dtype != NVTEDType::kNVTEFloat16 && ctx.q_dtype != NVTEDType::kNVTEBFloat16) { + ctx.set_error("Max512 backend requires FP16 or BF16 dtype"); + return false; + } - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); + // Check architecture + if (ctx.sm_arch != 80 && ctx.sm_arch != 90) { + ctx.set_error("Max512 backend requires sm80 or sm90, got sm" + std::to_string(ctx.sm_arch)); + 
return false; + } - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + // Check sequence length + if (ctx.max_seqlen_q > 512 || ctx.max_seqlen_kv > 512) { + ctx.set_error("Max512 backend requires seqlen <= 512, got q=" + + std::to_string(ctx.max_seqlen_q) + ", kv=" + std::to_string(ctx.max_seqlen_kv)); + return false; } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; + if (ctx.max_seqlen_q % 64 != 0 || ctx.max_seqlen_kv % 64 != 0) { + ctx.set_error("Max512 backend requires seqlen % 64 == 0"); + return false; } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + // Check head dimension + if (ctx.head_dim_qk != 64 || ctx.head_dim_v != 64) { + ctx.set_error("Max512 backend requires head_dim=64"); + return false; + } - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + // Check GQA + if (ctx.num_attn_heads != ctx.num_gqa_groups) { + ctx.set_error("Max512 backend does not support GQA"); + return false; + } - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, - wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); + // Check bias type + if (ctx.bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && + ctx.bias_type != NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { + ctx.set_error("Max512 backend requires NO_BIAS or POST_SCALE_BIAS"); + return false; } -} -// NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); - const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); - const Tensor *input_QKV = convertNVTETensorCheck(QKV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQKV = convertNVTETensorCheck(dQKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); + // Check mask type + bool mask_ok = false; + if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) { + mask_ok = true; + } else if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + ctx.max_seqlen_q == ctx.max_seqlen_kv) { + mask_ok = true; + } + if (!mask_ok) { + ctx.set_error("Max512 backend: unsupported mask type"); + return false; + } - auto ndim = input_QKV->data.shape.size(); - size_t b = input_cu_seqlens->data.shape[0] - 1; - size_t h = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - h = input_QKV->data.shape[ndim - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - h = input_QKV->data.shape[ndim - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + // Check layout + if (ctx.qkv_layout != NVTE_QKV_Layout::NVTE_SB3HD && + ctx.qkv_layout != NVTE_QKV_Layout::NVTE_SBHD_SB2HD && + ctx.qkv_layout != NVTE_QKV_Layout::NVTE_BS3HD && + ctx.qkv_layout != NVTE_QKV_Layout::NVTE_BSHD_BS2HD && + ctx.qkv_layout != NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + ctx.set_error("Max512 backend: unsupported QKV layout"); + return false; } - size_t d = input_QKV->data.shape[ndim - 1]; - size_t t = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - t = input_QKV->data.shape[0]; + + // Check window size + if (ctx.window_size_left != -1 || (ctx.window_size_right != -1 && ctx.window_size_right != 0)) { + ctx.set_error("Max512 backend requires does not support sliding window"); + return false; } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); + // Check ragged offset + 
if (!ctx.supported_ragged_offset_size) { + ctx.set_error("Max512 backend does not support 64-bit ragged offsets"); + return false; + } - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + // Check softmax type + if (ctx.softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) { + ctx.set_error("Max512 backend requires vanilla softmax type"); + return false; + } - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, - input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8900) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + return true; +} + +bool checks_for_arbitrary_seqlen(BackendSelectionContext &ctx) { + // Check dtype + if (ctx.q_dtype != NVTEDType::kNVTEFloat16 && ctx.q_dtype != NVTEDType::kNVTEBFloat16) { + ctx.set_error("ArbitrarySeqlen backend requires FP16 or BF16 dtype"); + return false; + } + + // Check architecture + bool arch_ok = false; + if (ctx.cudnn_version < 8903 && (ctx.sm_arch == 80 || ctx.sm_arch == 90)) { + arch_ok = true; + } else if (ctx.cudnn_version >= 8903 && ctx.sm_arch >= 80 && ctx.sm_arch < 100) { + arch_ok = true; + } else if (ctx.cudnn_version >= 90700 && ctx.sm_arch >= 80) { + arch_ok = true; + } + if (!arch_ok) { + ctx.set_error("ArbitrarySeqlen backend: unsupported sm" + std::to_string(ctx.sm_arch) + + " with cuDNN " + std::to_string(ctx.cudnn_version)); + return false; + } + + // Check sequence length + if (ctx.cudnn_version < 90000) { + if (ctx.max_seqlen_q % 64 != 0 || ctx.max_seqlen_kv % 64 != 0) { + ctx.set_error("ArbitrarySeqlen backend (cuDNN < 9.0) requires seqlen % 64 == 0"); + return false; } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + + // Check GQA + if (ctx.cudnn_version < 8907) { + if (ctx.num_attn_heads != ctx.num_gqa_groups) { + ctx.set_error("ArbitrarySeqlen backend (cuDNN < 8.9.7) does not support GQA"); + return false; } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, - input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); -#else - const char *err_msg = - "cuDNN 8.9.0 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. 
\n"; - NVTE_ERROR(err_msg); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } -} -// NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); - using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + // Check head dimension + if (ctx.head_dim_qk % 8 != 0 || ctx.head_dim_v % 8 != 0) { + ctx.set_error("ArbitrarySeqlen backend requires head_dim % 8 == 0"); + return false; } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format 
q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; + bool head_dim_ok = false; + // <= 128 + if (ctx.head_dim_qk <= 128 && ctx.head_dim_v <= 128) { + head_dim_ok = true; + // 9.1: <= 256 + Hopper + fprop + } else if (ctx.head_dim_qk <= 256 && ctx.head_dim_v <= 256 && !ctx.is_training && + ctx.sm_arch == 90 && ctx.cudnn_version >= 90100) { + head_dim_ok = true; + // 9.5: <= 256 + Hopper + bprop + } else if (ctx.head_dim_qk <= 256 && ctx.head_dim_v <= 256 && ctx.is_training && + ctx.sm_arch == 90 && ctx.cudnn_version >= 90500) { + head_dim_ok = true; + // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 + } else if (!ctx.is_training && ctx.sm_arch >= 100 && ctx.cudnn_version >= 90900 && + ctx.max_seqlen_q > 1 && + ctx.layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + head_dim_ok = true; + // 9.10.2: any head_dim + any arch + fprop + paged + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + } else if (!ctx.is_training && ctx.cudnn_version >= 91002 && + (ctx.layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || + ctx.max_seqlen_q > 1 || + (ctx.max_seqlen_q == 1 && ctx.attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && + ctx.attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) { + head_dim_ok = true; + // 9.11: d_qk=192, d_v=128 + Blackwell + bprop + non-paged + } else if (ctx.head_dim_qk == 192 && ctx.head_dim_v == 128 && ctx.is_training && + ctx.sm_arch >= 100 && ctx.cudnn_version >= 91100) { + head_dim_ok = true; } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; + // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + if (ctx.cudnn_version >= 91100 && ctx.is_training && ctx.sm_arch == 90 && + ctx.head_dim_qk >= 128 && ctx.head_dim_v >= 128 && + !(ctx.head_dim_qk == 192 && ctx.head_dim_v == 128) && ctx.head_dim_qk != ctx.head_dim_v) { + ctx.set_error("ArbitrarySeqlen backend: known cuDNN 9.11+ bug for sm90 bprop with MLA"); + return false; } - int64_t num_pages_k = 0; - int64_t num_pages_v = 0; - int64_t page_size_k = 0; - int64_t page_size_v = 0; - int64_t max_pages_per_seq_k = 0; - int64_t max_pages_per_seq_v = 0; - if (input_page_table_k->data.dptr != nullptr) { - max_pages_per_seq_k = input_page_table_k->data.shape[1]; + if (!head_dim_ok) { + ctx.set_error("ArbitrarySeqlen backend: unsupported head_dim (qk=" + + std::to_string(ctx.head_dim_qk) + ", v=" + std::to_string(ctx.head_dim_v) + ")"); + return false; } - if (input_page_table_v->data.dptr != nullptr) { - max_pages_per_seq_v = input_page_table_v->data.shape[1]; + + // Check bias type + bool bias_ok = false; + if (ctx.cudnn_version < 8906 && ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { + bias_ok = true; + } else if (ctx.cudnn_version >= 8906) { + if (ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { + bias_ok = true; + } else if (ctx.bias_type == NVTE_Bias_Type::NVTE_ALIBI && + ctx.attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && + ctx.attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + ctx.attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + ctx.attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + ctx.sm_arch >= 90) { + bias_ok = true; + } else if (ctx.bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && ctx.sm_arch >= 90) { + bias_ok = 
true; + } } - if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { - num_pages_k = input_KV->data.shape[0]; - page_size_k = input_KV->data.shape[1]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; - } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { - num_pages_k = input_KV->data.shape[1]; - page_size_k = input_KV->data.shape[0]; - num_pages_v = num_pages_v; - page_size_v = page_size_v; + if (ctx.cudnn_version >= 90000 && ctx.bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && + ctx.sm_arch >= 80) { + bias_ok = true; + } + if (!bias_ok) { + ctx.set_error("ArbitrarySeqlen backend: unsupported bias type"); + return false; + } + + // Check mask type + bool mask_ok = false; + // Pre-8.9.6: causal + if (ctx.cudnn_version < 8906 && ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) { + mask_ok = true; + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + } else if (ctx.cudnn_version >= 8906 && + (ctx.qkv_format == NVTE_QKV_Format::NVTE_SBHD || + ctx.qkv_format == NVTE_QKV_Format::NVTE_BSHD) && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) { + mask_ok = true; + // 9.1: adds thd + {padding, padding_causal} + } else if (ctx.cudnn_version >= 90100 && ctx.qkv_format == NVTE_QKV_Format::NVTE_THD && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) { + mask_ok = true; + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + } else if (ctx.cudnn_version >= 90300 && + (ctx.qkv_format == NVTE_QKV_Format::NVTE_SBHD || + ctx.qkv_format == NVTE_QKV_Format::NVTE_BSHD) && + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + ctx.max_seqlen_q % 64 == 0 && ctx.max_seqlen_kv % 64 == 0 && + ctx.max_seqlen_q <= ctx.max_seqlen_kv && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && ctx.dropout == 0.0) { + mask_ok = true; + // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} + } else if (ctx.cudnn_version >= 90500 && + ctx.layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + ctx.max_seqlen_q % 64 == 0 && ctx.max_seqlen_kv % 64 == 0 && + ctx.max_seqlen_q <= ctx.max_seqlen_kv)) && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && ctx.dropout == 0.0) { + mask_ok = true; + // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + } else if (ctx.cudnn_version >= 90600 && + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + ctx.max_seqlen_q % 64 == 0 && ctx.max_seqlen_kv % 64 == 0 && + ctx.max_seqlen_q <= ctx.max_seqlen_kv && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && ctx.dropout == 0.0) { + mask_ok = true; + // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} + // for any q_format/kv_format, and paged/non-paged + } else if (ctx.cudnn_version >= 90700) { + if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + ctx.attn_mask_type == 
NVTE_Mask_Type::NVTE_CAUSAL_MASK) { + mask_ok = true; + } else if ((ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && ctx.dropout == 0.0) { + mask_ok = true; + } else if ((ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + ctx.max_seqlen_q <= ctx.max_seqlen_kv) { + mask_ok = true; } } + if (!mask_ok) { + ctx.set_error("ArbitrarySeqlen backend: unsupported mask type for cuDNN" + + std::to_string(ctx.cudnn_version)); + return false; + } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + // Check bias + mask combination + if (ctx.cudnn_version >= 8906 && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + ctx.bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { + ctx.set_error("ArbitrarySeqlen backend: POST_SCALE_BIAS not supported with PADDING masks"); + return false; + } - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + // Check QKV format + bool format_ok = false; + if (ctx.qkv_format == NVTE_QKV_Format::NVTE_SBHD || + ctx.qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + format_ok = true; + } else if (ctx.qkv_format == NVTE_QKV_Format::NVTE_THD && ctx.sm_arch >= 90 && + ((ctx.cudnn_version >= 90100 && ctx.num_attn_heads == ctx.num_gqa_groups) || + ctx.cudnn_version >= 90600)) { + format_ok = true; + } else if (ctx.cudnn_version >= 90700 && + ((ctx.q_format == NVTE_QKV_Format::NVTE_SBHD || + ctx.q_format == NVTE_QKV_Format::NVTE_BSHD || + (ctx.q_format == NVTE_QKV_Format::NVTE_THD && ctx.sm_arch >= 90)) || + (ctx.kv_format == NVTE_QKV_Format::NVTE_SBHD || + ctx.kv_format == NVTE_QKV_Format::NVTE_BSHD || + (ctx.kv_format == NVTE_QKV_Format::NVTE_THD && ctx.sm_arch >= 90)))) { + format_ok = true; + } + if (!format_ok) { + ctx.set_error("ArbitrarySeqlen backend: unsupported QKV format"); + return false; + } - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); -#else - NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif + // Check sliding window + bool window_ok = false; + // Pre-9.2: full attention, causal + if (ctx.cudnn_version < 90200 && ctx.window_size_left == -1 && + (ctx.window_size_right == -1 || ctx.window_size_right == 0)) { + window_ok = true; + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} + } else if (ctx.cudnn_version >= 90200) { + if (ctx.window_size_left == -1 && (ctx.window_size_right == -1 || ctx.window_size_right == 0)) { + window_ok = true; + } else if ((ctx.window_size_left >= 0 || ctx.window_size_left == -1) && + ctx.window_size_right == 0 && + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + ctx.max_seqlen_q == ctx.max_seqlen_kv)) && + ctx.max_seqlen_q <= ctx.max_seqlen_kv && ctx.dropout == 0.0 && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + (ctx.qkv_format == NVTE_QKV_Format::NVTE_BSHD || + ctx.qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + window_ok = true; + } + } + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + if (ctx.cudnn_version >= 90600) { + if (ctx.window_size_left == -1 && (ctx.window_size_right == -1 || ctx.window_size_right == 0)) { + window_ok = true; + } else if ((ctx.window_size_left >= 0 || ctx.window_size_left == -1) && + ctx.window_size_right == 0) { + bool mask_and_arch_ok = false; + + if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + (ctx.sm_arch < 100 || (ctx.sm_arch >= 100 && ((ctx.max_seqlen_q == ctx.max_seqlen_kv && + ctx.cudnn_version <= 90700) || + ctx.cudnn_version > 90700)))) { + mask_and_arch_ok = true; + } else if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { + mask_and_arch_ok = true; + } else if (ctx.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + (ctx.sm_arch < 100 || + (ctx.sm_arch >= 100 && + ((ctx.max_seqlen_q == ctx.max_seqlen_kv && ctx.cudnn_version <= 90700) || + ctx.cudnn_version > 90700)))) { + mask_and_arch_ok = true; + } + if (mask_and_arch_ok && ctx.max_seqlen_q <= ctx.max_seqlen_kv && + ctx.bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && ctx.dropout == 0.0) { + window_ok = true; + } + } + } + if (!window_ok) { + 
ctx.set_error("ArbitrarySeqlen backend: unsupported sliding window configuration"); + return false; + } + + // Check ragged offset + if (!ctx.supported_ragged_offset_size) { + ctx.set_error("ArbitrarySeqlen backend does not support 64-bit ragged offset"); + return false; + } + + // Check known bugs + if (ctx.cudnn_version == 91000 || ctx.cudnn_version == 91001) { + ctx.set_error("ArbitrarySeqlen backend: known bugs with SDPA F16 in cuDNN 9.10.0/9.10.1"); + return false; + } + + // Check softmax type + if (ctx.cudnn_version >= 91301) { + // 9.13.1+: vanilla, off-by-one, learnable } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + // pre-9.13.1: vanilla + if (ctx.softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) { + ctx.set_error("ArbitrarySeqlen backend (cuDNN < 9.13.1) requires vanilla softmax type"); + return false; + } } + + return true; } -// NVTE fused attention BWD with packed KV -void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); +} // namespace + +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_KV = convertNVTETensorCheck(KV); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dKV = convertNVTETensorCheck(dKV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); + NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); - size_t b = input_cu_seqlens_q->data.shape[0] - 1; - auto ndim = input_Q->data.shape.size(); - size_t h_q = input_Q->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; - auto ndim_kv = input_KV->data.shape.size(); - size_t h_kv = 0; - NVTE_QKV_Layout_Group layout_group = 
nvte_get_qkv_layout_group(qkv_layout); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - h_kv = input_KV->data.shape[ndim_kv - 2]; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - h_kv = input_KV->data.shape[ndim_kv - 3]; - } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); - } - size_t t_q = 0; - size_t t_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD) { - t_q = input_Q->data.shape[0]; - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - t_kv = input_KV->data.shape[0]; + BackendSelectionContext ctx; + ctx.is_training = is_training; + ctx.q_dtype = q_dtype; + ctx.qkv_layout = qkv_layout; + ctx.bias_type = bias_type; + ctx.attn_mask_type = attn_mask_type; + ctx.softmax_type = softmax_type; + ctx.dropout = dropout; + ctx.num_attn_heads = num_attn_heads; + ctx.num_gqa_groups = num_gqa_groups; + ctx.max_seqlen_q = max_seqlen_q; + ctx.max_seqlen_kv = max_seqlen_kv; + ctx.head_dim_qk = head_dim_qk; + ctx.head_dim_v = head_dim_v; + ctx.window_size_left = window_size_left; + ctx.window_size_right = window_size_right; + + const int device_id = cuda::current_device(); + ctx.sm_arch = cuda::sm_arch(device_id); + ctx.cudnn_version = cudnnGetVersion(); + ctx.qkv_format = nvte_get_qkv_format(qkv_layout); + ctx.q_format = nvte_get_q_format(qkv_layout); + ctx.kv_format = nvte_get_kv_format(qkv_layout); + ctx.layout_group = nvte_get_qkv_layout_group(qkv_layout); + ctx.requires_64bit_ragged_offset = + (ctx.qkv_format == NVTE_THD && + fused_attn::get_ragged_offset_dtype(ctx.layout_group, num_attn_heads, num_gqa_groups, + max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v) == DType::kInt64); + ctx.supported_ragged_offset_size = + (!ctx.requires_64bit_ragged_offset || ctx.cudnn_version >= 90500); + + // Try FP8 backend + if (checks_for_fp8(ctx)) { + if (ctx.cudnn_version >= 8900) { + return NVTE_Fused_Attn_Backend::NVTE_FP8; + } else { + std::cout << "Warning: FP8 fused attention requires cuDNN 8.9.0+. " + << "Please upgrade your cuDNN version." 
<< std::endl; + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } } - auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); - const NVTEDType Q_type = static_cast(input_Q->data.dtype); - const NVTEDType KV_type = static_cast(input_KV->data.dtype); + // Try F16/BF16 backends + if (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16) { + bool can_use_max512 = checks_for_max512(ctx); + std::string max512_error = ctx.error_msg; - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + bool can_use_arbitrary = checks_for_arbitrary_seqlen(ctx); + std::string arbitrary_error = ctx.error_msg; - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { -#if (CUDNN_VERSION >= 8901) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { -#if (CUDNN_VERSION >= 8903) - size_t i = 0; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - Tensor *input_Bias, *input_SoftmaxOffset; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + // Select backend based on seqlen and availability + NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + + if (max_seqlen_q > 512 || max_seqlen_kv > 512) { + // Must use arbitrary + if (can_use_arbitrary) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } else { + std::cout << "Warning: No fused attention backend available. " << arbitrary_error + << std::endl; + } + } else { + // seqlen <= 512: prefer arbitrary, fallback to max512 + if (can_use_arbitrary) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + } else if (can_use_max512) { + backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; + } else { + std::cout << "Warning: No fused attention backend available." << std::endl; + std::cout << " Max512: " << max512_error << std::endl; + std::cout << " Arbitrary: " << arbitrary_error << std::endl; + } + + // Environment variable override + int env_backend = static_cast(backend); + env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); + + if ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) && + can_use_max512) || + (env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) && + can_use_arbitrary)) { + backend = static_cast(env_backend); + } } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + + // Validate cuDNN version for selected backend + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen && ctx.cudnn_version < 8901) { + std::cout << "Warning: FP16/BF16 fused attention (max512) requires cuDNN 8.9.1+. 
" + << "Please upgrade your cuDNN version." << std::endl; + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, - output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); -#else - const char *err_msg = - "cuDNN 8.9.3 is required for BF16/FP16 fused attention " - "with arbitrary sequence length. \n"; - NVTE_ERROR(err_msg); -#endif - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { -#if (CUDNN_VERSION >= 8900) - const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); -#else - NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); -#endif - } else { - NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); + + if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && ctx.cudnn_version < 8900) { + std::cout << "Warning: FP16/BF16 fused attention (arbitrary) requires cuDNN 8.9.0+. " + << "Please upgrade your cuDNN version." << std::endl; + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + return backend; } + + // No backend available + std::cout << "Warning: No fused attention backend available for the given configuration." 
+ << std::endl; + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } + // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index ba0f845789..cec0bfda27 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -999,473 +999,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; 
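Editor's note: the removed packed-QKV/KV helpers above and below all follow the same two-phase protocol for Aux_CTX_Tensors. A first call with size == 0 only reports how many auxiliary tensors are needed and what shapes and dtypes they have; a second call, after the caller has allocated them, binds device pointers into the same slots. The following stand-alone C++ sketch illustrates that handshake; SimpleTensor, TensorPack and setup_aux_tensors are simplified stand-ins for this illustration, not the library's Tensor/NVTETensorPack API.

// Minimal, self-contained sketch of the two-phase auxiliary-tensor protocol.
#include <cstddef>
#include <iostream>
#include <vector>

struct SimpleTensor {
  void *dptr = nullptr;        // device pointer; stays null in the query pass
  std::vector<size_t> shape;   // logical shape reported in the query pass
};

struct TensorPack {
  size_t size = 0;             // number of valid entries in `tensors`
  SimpleTensor tensors[8];
};

// Pass 1 (pack.size == 0): describe the tensors so the caller can allocate them.
// Pass 2 (pack.size != 0): the caller has allocated them; bind real pointers.
void setup_aux_tensors(TensorPack &pack, size_t batch, size_t heads, size_t seqlen,
                       void *softmax_stats_ptr, void *rng_state_ptr) {
  if (pack.size == 0) {
    pack.tensors[0].shape = {batch, heads, seqlen, 1};  // softmax statistics
    pack.tensors[1].shape = {2};                        // rng state: seed, offset
    pack.size = 2;
    return;
  }
  pack.tensors[0].dptr = softmax_stats_ptr;
  pack.tensors[1].dptr = rng_state_ptr;
}

int main() {
  TensorPack pack;
  setup_aux_tensors(pack, 2, 16, 512, nullptr, nullptr);  // query pass
  std::cout << "aux tensors needed: " << pack.size << "\n";

  float stats = 0.0f;
  long long rng[2] = {0, 0};
  setup_aux_tensors(pack, 2, 16, 512, &stats, rng);       // binding pass
  return 0;
}

In the real helpers the pack can also grow by a bias tensor and a softmax-offset tensor when bias_type and softmax_type call for them, which is why the binding pass checks Aux_CTX_Tensors->size >= 2 rather than an exact count.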
- output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, - nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = 
nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, 
int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor 
*output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = 
(typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index b9658b0530..f22b11044c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -18,51 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor 
*cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 89528fa3c4..1028df6452 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd_qkvpacked( - size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - const auto stride = 2 * num_head * head_dim; - - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = static_cast(input_Bias->data.dptr); - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; - output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrCuSeqlen = cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - const DType QKV_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - 
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, - "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); - - // Q shape is [b, s, h, d] - void *devPtrQ = input_Q->data.dptr; - - // KV shape is [b, s, 2, h, d] - const auto stride = 2 * num_head * head_dim; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, } } -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 
3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - - auto stride = 2 * num_head * head_dim; - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQKV shape is [b, s, 3, h, d] - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - - const auto qkv_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, - devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, - devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // Q shape is [b, s, h, d] - // KV shape is [b, s, 2, h, d] - auto stride = 2 * num_head * head_dim; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQ shape is [b, s, h, d] - // dKV shape is [b, s, 2, h, d] - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dKV->data.dptr; - void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - 
workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 171fe846ce..57b7afcf43 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -18,25 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8901) -void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_size, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git 
a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 21c544491a..3c50e7ab89 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2405,424 +2405,6 @@ void fused_attn_fp8_bwd_impl_v1( } // namespace fused_attn #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = static_cast(devPtrQKV); - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - 
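Editor's note: two conventions recur in the code right above. The rng_state tensor is a two-element int64 buffer whose first element is the dropout seed and whose second is the offset, split by plain pointer arithmetic; and the workspace is sized by a first call that passes a null workspace pointer so the implementation can report the byte count it needs. The following stand-alone C++ sketch shows both patterns; attention_impl and its fixed size requirement are hypothetical stand-ins, not the actual cuDNN-backed implementation.

// Self-contained sketch of the seed/offset split and the query-then-allocate
// workspace handshake used by these helpers.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the real implementation: with a null workspace pointer it only
// reports the size it needs; otherwise it runs with the provided buffer.
void attention_impl(const uint64_t *seed, const uint64_t *offset,
                    void *workspace, size_t *workspace_size) {
  constexpr size_t kNeeded = 4096;  // illustrative requirement
  if (workspace == nullptr) {
    *workspace_size = kNeeded;
    return;
  }
  std::cout << "running with seed=" << *seed << " offset=" << *offset
            << " and a " << *workspace_size << "-byte workspace\n";
}

int main() {
  // (1) Seed/offset split: element 0 is the seed, element 1 the offset,
  //     mirroring reinterpret_cast<uint64_t *>(rng_state_ptr) + 1 above.
  uint64_t rng_state[2] = {1234, 0};
  const uint64_t *seed = rng_state;
  const uint64_t *offset = rng_state + 1;

  // (2) First pass: query the required workspace size.
  size_t workspace_size = 0;
  attention_impl(seed, offset, nullptr, &workspace_size);

  // Allocate the reported amount and make the real call.
  std::vector<std::byte> workspace(workspace_size);
  attention_impl(seed, offset, workspace.data(), &workspace_size);
  return 0;
}

In the helpers themselves, the query pass is signalled back by setting workspace->data.shape to the reported size (or {1} when no workspace is needed) and returning, so the caller can allocate the buffer and invoke the function a second time.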
size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, - devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQKV->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = devPtrQKV; - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* 
devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQKV = output_dQKV->data.dptr; - void* devPtrdQ = devPtrdQKV; - void* devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void* devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - void* devPtrAmaxdQ = output_dQKV->amax.dptr; - void* devPtrAmaxdK = output_dQKV->amax.dptr; - void* devPtrAmaxdV = output_dQKV->amax.dptr; - void* devPtrScaledQ = output_dQKV->scale.dptr; - void* devPtrScaledK = output_dQKV->scale.dptr; - void* devPtrScaledV = output_dQKV->scale.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, - devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, - devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, - devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = 
- reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQ->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - 
devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQ = output_dQ->data.dptr; - void* devPtrdKV = output_dKV->data.dptr; - void* devPtrdK = devPtrdKV; - void* devPtrdV = static_cast<void*>(static_cast<int8_t*>(devPtrdKV) + stride); - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dKV->amax.dptr; - void* devPtrAmaxdV = output_dKV->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dKV->scale.dptr; - void* devPtrScaledV = output_dKV->scale.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d162..c2efa25829 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,47 +13,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git 
a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index a150978c4a..c5c0fca859 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -214,260 +214,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); -/*! \brief Compute dot product attention with packed QKV input. - * - * Computes: - * - P = Q * Transpose(K) + Bias - * - S = ScaleMaskSoftmax(P) - * - D = Dropout(S) - * - O = D * Transpose(V) - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * - * Notes: - * - * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences - * in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. - * When the QKV format is `thd`, this tensor should follow the following rules. - * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. - * \param[in] Bias The Bias tensor. - * \param[in] SoftmaxOffset The SoftmaxOffset tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. - * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(seqlen_i) for i=0,...batch_size-1. - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. 
- */ -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, - bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute the backward of the dot product attention with packed QKV input. - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * - * Notes: - * - * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences - * in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`. - * When the QKV format is `thd`, this tensor should follow the following rules. - * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`, - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. - * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. - * \param[out] dQKV The gradient of the QKV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. - * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. - * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. - * \param[in] max_seqlen Max sequence length used for computing, - * it may be >= max(seqlen_i) for i=0,...batch_size-1. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] deterministic Whether to execute with deterministic behaviours. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. 
- */ -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute dot product attention with packed KV input. - * - * Computes: - * - P = Q * Transpose(K) + Bias - * - S = ScaleMaskSoftmax(P) - * - D = Dropout(S) - * - O = D * Transpose(V) - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - \endverbatim - * - * Notes: - * - * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` - * help identify the correct offsets of different sequences in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. - * When the QKV format is `thd`, these tensors should follow the following rules. - * When there is no padding between sequences, the offset tensors should be equal to - * `cu_seqlens_q` and `cu_seqlens_kv` respectively. - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] Q The Q tensor, in HD layouts. - * \param[in] KV The KV tensor, in 2HD or H2D layouts. - * \param[in] Bias The Bias tensor. - * \param[in] SoftmaxOffset The SoftmaxOffset tensor. - * \param[in,out] S The S tensor. - * \param[out] O The output O tensor. - * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. - * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. - * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. - * \param[in] rng_state Seed and offset of CUDA random number generator. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. - * \param[in] is_training Whether this is in training mode or inference. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. 
- * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] deterministic Whether to execute with deterministic behaviours. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); - -/*! \brief Compute the backward of the dot product attention with packed KV input. - * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - \endverbatim - * - * Notes: - * - * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` - * help identify the correct offsets of different sequences in tensors Q, K, V and O. - * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, - * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. - * When the QKV format is `thd`, these tensors should follow the following rules. - * When there is no padding between sequences, the offset tensors should be equal to - * `cu_seqlens_q` and `cu_seqlens_kv` respectively. - * When there is padding between sequences, users are responsible to adjust the offsets as needed. - * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have - * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`. - * - * \param[in] Q The Q tensor, in HD layouts. - * \param[in] KV The KV tensor, in H2D or 2HD layouts. - * \param[in] O The O tensor from forward. - * \param[in] dO The gradient of the O tensor. - * \param[in] S The S tensor. - * \param[in,out] dP The gradient of the P tensor. - * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. - * \param[out] dQ The gradient of the Q tensor. - * \param[out] dKV The gradient of the KV tensor. - * \param[out] dBias The gradient of the Bias tensor. - * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. - * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. - * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. 
- * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. - * \param[in] max_seqlen_q Max sequence length used for computing for Q. - * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. - * \param[in] max_seqlen_kv Max sequence length used for computing for KV. - * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. - * \param[in] attn_scale Scaling factor for Q * K.T. - * \param[in] dropout Dropout probability. - * \param[in] qkv_layout QKV tensor's layout. - * \param[in] bias_type Bias type. - * \param[in] attn_mask_type Attention mask type. - * \param[in] softmax_type Attention softmax type. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). - * \param[in] deterministic Whether to execute with deterministic behaviours. - * \param[in] workspace Workspace tensor. - * \param[in] stream CUDA stream used for this operation. - */ -void nvte_fused_attn_bwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); - /*! \brief Compute dot product attention with separate Q, K and V. * * Computes: diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9277569e11..888decc7b6 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -122,20 +122,11 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; @@ -155,7 +146,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - auto 
layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; @@ -173,36 +163,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); auto ragged_offset_tensor = TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, - query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported QKVLayout."); - } + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -288,48 +256,40 @@ static void FusedAttnForwardImpl( /* Call the underlying NVTE API */ auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32); + TensorWrapper q_tensor, k_tensor, v_tensor; + auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = 
std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + auto stride = typeToSize(dtype) * attn_heads * qk_head_dim; + q_tensor = TensorWrapper(q, q_shape, dtype); + k_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(q) + stride), k_shape, dtype); + v_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(q) + stride * 2), v_shape, dtype); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, - workspace_tensor.data(), stream); + auto stride = typeToSize(dtype) * num_gqa_groups * qk_head_dim; + q_tensor = TensorWrapper(q, q_shape, dtype); + k_tensor = TensorWrapper(k, k_shape, dtype); + v_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(k) + stride), v_shape, dtype); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + q_tensor = TensorWrapper(q, q_shape, dtype); + k_tensor = TensorWrapper(k, k_shape, dtype); + v_tensor = TensorWrapper(v, v_shape, dtype); } else { NVTE_ERROR("Unsupported qkv_layout."); } + 
nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); + nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -411,24 +371,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v - auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); @@ -447,7 +396,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper query_workspace_tensor; - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? 
input_batch * max_segments_per_seq : input_batch; @@ -468,42 +416,17 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -548,76 +471,76 @@ static void FusedAttnBackwardImpl( softmax_aux, rng_state, bias); /* Call the underly NVTE API */ + TensorWrapper q_tensor, dq_tensor, k_tensor, dk_tensor, v_tensor, dv_tensor; + auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = 
std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); + q_tensor = TensorWrapper(q, q_shape, dtype); + dq_tensor = TensorWrapper(dq, q_shape, dtype); + + auto stride = typeToSize(dtype) * attn_heads * qk_head_dim; + k_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(q) + stride), k_shape, dtype); + dk_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(dq) + stride), k_shape, dtype); + v_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(q) + stride * 2), v_shape, dtype); + dv_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(dq) + stride * 2), v_shape, dtype); + if (is_ragged) { + auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), stream); } - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, deterministic, - workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); + q_tensor = TensorWrapper(q, q_shape, dtype); + dq_tensor = TensorWrapper(dq, q_shape, dtype); + k_tensor = TensorWrapper(k, k_shape, dtype); + dk_tensor = TensorWrapper(dk, k_shape, dtype); + + auto stride = typeToSize(dtype) * num_gqa_groups * qk_head_dim; + v_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(k) + stride), v_shape, dtype); + dv_tensor = + TensorWrapper(static_cast<void*>(static_cast<int8_t*>(dk) + stride), v_shape, dtype); + if (is_ragged) { + auto kv_shape = + std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), stream); } - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, deterministic, - workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto 
k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); + q_tensor = TensorWrapper(q, q_shape, dtype); + dq_tensor = TensorWrapper(dq, q_shape, dtype); + k_tensor = TensorWrapper(k, k_shape, dtype); + dk_tensor = TensorWrapper(dk, k_shape, dtype); + v_tensor = TensorWrapper(v, v_shape, dtype); + dv_tensor = TensorWrapper(dv, v_shape, dtype); + if (is_ragged) { cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, deterministic, - workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); + nvte_tensor_pack_destroy(&aux_input_tensors); }
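Reviewer note on the JAX changes above: the bindings no longer dispatch to the layout-specific nvte_fused_attn_fwd_qkvpacked / nvte_fused_attn_bwd_kvpacked entry points. Every layout group is lowered to the separate-Q/K/V calls by viewing the packed buffer at fixed byte offsets: for 3HD, K sits one head-slab past Q and V two slabs past Q; for HD_2HD, V sits one slab past K, exactly as the added TensorWrapper lines compute with typeToSize(dtype). The sketch below shows only that pointer arithmetic; the names QKVViews, split_packed_qkv, split_packed_kv, and bytes_per_elt are illustrative and not part of the patch or the TE API.

    // Sketch (C++): deriving per-tensor base pointers from a packed QKV or KV buffer.
    // Mirrors the stride logic added in FusedAttnForwardImpl / FusedAttnBackwardImpl.
    #include <cstddef>
    #include <cstdint>

    struct QKVViews {  // illustrative container, not a TE type
      void* q;
      void* k;
      void* v;
    };

    // 3HD-packed: per token the buffer holds [Q | K | V], each heads * head_dim elements,
    // so K starts one slab (bytes_per_elt * heads * head_dim) after Q and V two slabs after Q.
    inline QKVViews split_packed_qkv(void* qkv, std::size_t bytes_per_elt, std::size_t heads,
                                     std::size_t head_dim) {
      const std::size_t stride = bytes_per_elt * heads * head_dim;
      auto* base = static_cast<std::int8_t*>(qkv);
      return {static_cast<void*>(base), static_cast<void*>(base + stride),
              static_cast<void*>(base + 2 * stride)};
    }

    // HD_2HD-packed: Q is a separate buffer; K and V share one, with V one slab after K.
    inline QKVViews split_packed_kv(void* q, void* kv, std::size_t bytes_per_elt,
                                    std::size_t gqa_groups, std::size_t head_dim) {
      const std::size_t stride = bytes_per_elt * gqa_groups * head_dim;
      auto* kv_base = static_cast<std::int8_t*>(kv);
      return {q, static_cast<void*>(kv_base), static_cast<void*>(kv_base + stride)};
    }

The same offsets are applied to the gradient buffers (dq, dk) in the backward path, which is why dQKV and dKV no longer need dedicated packed TensorWrappers before calling nvte_fused_attn_bwd.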