diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 80a8e4af4d..bb37575d10 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4 +Subproject commit bb37575d103b9974bc619a193dc1a96d835dc117 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index d490c235bb..a272680ec2 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -249,6 +249,7 @@ def run_dpa_with_cp( attn_mask_type=config.attn_mask_type, window_size=config.window_size, softmax_type=config.softmax_type, + return_max_score=config.return_max_score, ).cuda() if config.softmax_type != "vanilla": core_attn.softmax_offset.requires_grad = True @@ -309,6 +310,7 @@ def run_dpa_with_cp( fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() + max_score = None with fp8_context: # q, k, v, out in FP8; dout in F16 out = core_attn( @@ -323,6 +325,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_score: + out, max_score = out if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) @@ -401,6 +405,7 @@ def run_dpa_with_cp( fp8_context = nullcontext() # run attention + max_score_ = None with fp8_context: # q, k, v, out in FP8; dout in F16 out_ = core_attn( @@ -415,6 +420,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_score: + out_, max_score_ = out_ if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) @@ -496,15 +503,15 @@ def run_dpa_with_cp( ) atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] - names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_score_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_score] + names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_score"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names] is_fp8 = dtype == "fp8" for i, t in enumerate(tensors_no_cp): if t is not None: - if "softmax_offset" not in names[i]: + if "softmax_offset" not in names[i] and "max_score" not in names[i]: if qkv_format == "bshd": compare_and_assert( t[:, 0], diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index e3a4de73b0..a68ecde51f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -166,7 +166,7 @@ def test_dot_product_attention( # UnfusedDotProductAttention backend if unfused_attn_supported: - unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( + unfused_attn_fwd, unfused_max_score, unfused_attn_bwd = _run_dot_product_attention( dtype, config, "UnfusedDotProductAttention", @@ -180,7 +180,7 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: if len(fused_attn_backends) == 1: - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -192,7 +192,7 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" - fused_attn_fwd, fused_attn_bwd = 
_run_dot_product_attention( + fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -203,7 +203,7 @@ def test_dot_product_attention( is_training, ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( + fused_attn_fwd_1, fused_max_score_1, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -216,7 +216,7 @@ def test_dot_product_attention( # FlashAttention backend if flash_attn_supported: - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( + flash_attn_fwd, flash_max_score, flash_attn_bwd = _run_dot_product_attention( dtype, config, "FlashAttention", @@ -232,16 +232,22 @@ def test_dot_product_attention( if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(flash_max_score, unfused_max_score, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) if unfused_attn_supported and fused_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(fused_max_score, unfused_max_score, **tols) for i, _ in enumerate(unfused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) + if config.return_max_score: + torch.testing.assert_close(fused_max_score, flash_max_score, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) if fused_attn_supported and len(fused_attn_backends) == 2: @@ -260,6 +266,40 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_max_score = { + # test: ModelConfig(b, sq, hq, dqk) + "max_score_1_0": ModelConfig(8, 128, 16, 64), + "max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), + "max_score_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), + "max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "max_score_3_0": ModelConfig( + 8, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal" + ), + "max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), + "max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), + "max_score_4_1": ModelConfig( + 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" + ), + "max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), + "max_score_5_1": ModelConfig( + 8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0) + ), + "max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), + "max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_max_score]) +@pytest.mark.parametrize("model", model_configs_max_score.keys()) +def test_dpa_max_score(dtype, model_configs, model): 
+ """Test DotProductAttention module with checkpointing""" + config = model_configs[model] + config.return_max_score = True + test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) + + model_configs_softmax = { # test: ModelConfig(b, sq, hq, dqk) "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), @@ -1065,6 +1105,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type=config.attn_type, softmax_type=config.softmax_type, + return_max_score=config.return_max_score, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() @@ -1102,16 +1143,20 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: alibi_slopes=alibi_slopes, fast_zero_fill=True, ) + max_score = None + if config.return_max_score: + out, max_score = out if is_training: out.backward(d_out) + d_softmax_offset = None if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None, d_softmax_offset) + return out, max_score, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1140,14 +1185,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) + return ( + out_orig, + max_score, + (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset), + ) else: - return out_orig, (None, None, None, d_softmax_offset) + return out_orig, max_score, (None, None, None, d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_score, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None, d_softmax_offset) + return out, max_score, (None, None, None, d_softmax_offset) model_configs_te_layer = { diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0f00b8b0ef..05585d3462 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: ModelConfig(b, sq, hq, dqk) - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_score=True), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_score=True), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] diff --git a/tests/pytorch/utils.py 
b/tests/pytorch/utils.py index d77256b7f9..86ffebd2b7 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -205,6 +205,7 @@ def __init__( window_size: Tuple[int, int] = (-1, -1), context_parallel: bool = False, cp_comm_type: str = "p2p", + return_max_score=False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -233,6 +234,7 @@ def __init__( self.window_size = check_set_window_size(self.attn_mask_type, window_size) self.context_parallel = context_parallel self.cp_comm_type = cp_comm_type + self.return_max_score = return_max_score self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -318,6 +320,7 @@ def test(): is_training=is_training, inference_params=inference_params, softmax_type=config.softmax_type, + return_max_score=config.return_max_score, ) ( use_flash_attention, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 77cd8d235a..1e4cec2dfd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool return_max_score) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000)) { + (cudnn_runtime_version != 91000) && !return_max_score) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_score) { flag_m512 = true; } if ( @@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, + size_t max_seqlen, bool is_training, bool return_max_score, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, @@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, 
softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_score); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + b, h, max_seqlen, d, t, is_training, return_max_score, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_score); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, 
input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, + output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked( } } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_score, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_score); if 
(fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_score, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index ba0f845789..590b50ad4a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,10 +53,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + bool return_max_score, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, @@ -102,36 +102,41 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } const DType ragged_offset_type = cudnn_runtime_version >= 90500 
? DType::kInt64 : DType::kInt32; + bool generate_stats = !return_max_score; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - true, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + true, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + generate_stats, + return_max_score, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -141,7 +146,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // V std::shared_ptr, // attn_scale std::shared_ptr, // O - std::shared_ptr, // Stats + std::shared_ptr, // S1 + std::shared_ptr, // S2 std::shared_ptr, // bias std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q @@ -244,6 +250,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) + .set_generate_stats(generate_stats) .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); @@ -317,7 +324,23 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_sink_token(softmax_offset); } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); + std::shared_ptr Max, Sum_Exp; + if (return_max_score) { + Max = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Max") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_logit_max(Max); + Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Sum_Exp") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_score_sum_exp(Sum_Exp); + } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, @@ -332,17 +355,19 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { - offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); - } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); + if (!return_max_score) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + 
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -351,7 +376,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr> // O key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = std::make_tuple(Stats); + auto Stats_tuple = + generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto softmax_offset_tuple = is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); @@ -384,7 +410,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv, + auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, softmax_offset, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); @@ -417,9 +443,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // Build variant pack std::unordered_map, void *> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, - {V, devPtrV}, {attn_scale, &scaling_factor}, - {O, devPtrO}, {Stats, devPtrSoftmaxStats}}; + {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, + {O, devPtrO}, {S1, devPtrS1}}; + + if (return_max_score) { + variant_pack[S2] = devPtrS2; + } if (is_bias) { variant_pack[bias] = devPtrBias; @@ -561,35 +590,39 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - deterministic, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + 0, + 0, + 0, + 0, + 0, + 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + deterministic, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + true, + true, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1001,12 +1034,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool is_training, bool return_max_score, float attn_scale, float p_dropout, + 
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -1037,7 +1071,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -1051,14 +1086,34 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1080,8 +1135,15 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ 
-1105,11 +1167,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, - nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1221,14 +1283,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_score, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1260,7 +1323,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1285,14 +1349,34 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1314,8 +1398,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1340,11 +1431,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + 
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1471,14 +1563,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1488,7 +1580,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; @@ -1525,14 +1618,34 @@ void fused_attn_arbitrary_seqlen_fwd( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + 
output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1554,8 +1667,15 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_score) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1580,11 +1700,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index b9658b0530..094b04da5c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -20,12 +20,13 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t 
window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool is_training, bool return_max_score, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, @@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_score, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, @@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor 
*input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool return_max_score, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 21c544491a..b5bd1b4109 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1710,7 +1710,9 @@ void fused_attn_fp8_fwd_impl_v1( qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + cudnn_frontend::DataType_t::NOT_SET, + false, + true}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2038,7 +2040,9 @@ void fused_attn_fp8_bwd_impl_v1( qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type}; + dqkv_tensor_type, + false, + true}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f03774f8ed..b3cd815ddd 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -115,20 +115,23 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t o_tensor_type; cudnn_frontend::DataType_t do_tensor_type; cudnn_frontend::DataType_t dqkv_tensor_type; + bool generate_stats; + bool generate_max_sum_exp; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type) < + o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_stats, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type); + rhs.dqkv_tensor_type, rhs.generate_stats, rhs.generate_max_sum_exp); } }; diff --git 
a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index a150978c4a..8f03d5e187 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] softmax_type The attention softmax type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool return_max_score); /*! \brief Compute dot product attention with packed QKV input. * @@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. 
* \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, - bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_score, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] max_seqlen_kv Max sequence length used for computing for KV. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_kv, bool is_training, bool return_max_score, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); @@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_score Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. @@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_score, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9277569e11..ffc0706fe7 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); return backend; } @@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), + nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, 
kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), @@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, + kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); @@ -276,7 +278,8 @@ static void FusedAttnForwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -294,7 +297,7 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_qkvpacked( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, + q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { @@ -308,8 +311,8 @@ static void FusedAttnForwardImpl( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, + q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -323,7 +326,7 @@ static void FusedAttnForwardImpl( dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, 
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { @@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0ddb261d2e..d9681d04fa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -201,6 +201,7 @@ def __init__( attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, softmax_type: str = "vanilla", + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -209,6 +210,7 @@ def __init__( self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number self.softmax_type = softmax_type + self.return_max_score = return_max_score def mask_func(x, y): return ( @@ -217,6 +219,8 @@ def mask_func(x, y): else attention_mask_func(x, y) ) + self.mask_func = mask_func + self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) # Dropout. 
Note that for a single iteration, this layer will generate @@ -426,6 +430,17 @@ def forward( matmul_result, None, None, dP_quantizer, "dP_quantizer", None ) + # max attention score + max_score = None + if self.return_max_score: + # matmul_result [b, np, sq, dk], max_score [np] + max_score = matmul_result + if attn_mask_type != "no_mask": + max_score = self.mask_func(matmul_result, attention_mask) + with self.attention_dropout_ctx(): + max_score = self.attention_dropout(max_score) + max_score = torch.amax(max_score, dim=(0, 2, 3)) + # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( @@ -529,6 +544,9 @@ def forward( if fp8_output: context_layer = O_quantizer(context_layer) + if self.return_max_score: + return context_layer, max_score + return context_layer @@ -799,6 +817,7 @@ def forward( batch_size * context_len, ) + max_score = None use_flash_attn_3 = False if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): use_flash_attn_3 = True @@ -1067,6 +1086,7 @@ def forward( softmax_offset, fp8_output, layer_number, + return_max_score, ): # pylint: disable=missing-function-docstring @@ -1102,6 +1122,7 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype + max_score = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1129,7 +1150,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1205,7 +1226,7 @@ def forward( qkvo_tensors = (q, k, v, out) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1233,6 +1254,7 @@ def forward( window_size, rng_gen, softmax_offset, + return_max_score, ) out = out_ out_ret = out_ @@ -1304,10 +1326,10 @@ def forward( ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, d_out): + def backward(ctx, d_out, *args): # pylint: disable=missing-function-docstring # d_out is expected to be in FP8 if is_output_fp8=True, @@ -1551,6 +1573,7 @@ def backward(ctx, d_out): d_softmax_offset, None, None, + None, ) @@ -1591,6 +1614,7 @@ def __init__( layer_number: Optional[int] = None, deterministic: bool = False, softmax_type: str = "vanilla", + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -1604,6 +1628,7 @@ def __init__( self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic self.softmax_type = softmax_type + self.return_max_score = return_max_score def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1823,6 +1848,7 @@ def forward( softmax_offset=softmax_offset, fp8_output=fp8_output, layer_number=self.layer_number, + return_max_score=self.return_max_score, ) else: with self.attention_dropout_ctx(): @@ -1858,7 +1884,11 @@ def forward( softmax_offset, fp8_output, self.layer_number, + self.return_max_score, ) + if self.return_max_score: + # ...hd -> ...(hd) + return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) - return 
output.view(*output.shape[:-2], -1) + return output[0].view(*output[0].shape[:-2], -1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d1374e949e..6a8659cff1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -617,6 +617,7 @@ def cp_p2p_fwd_fused_attn( rank, step, cp_size, + return_max_score, q_part, k_part, v_part, @@ -693,7 +694,7 @@ def cp_p2p_fwd_fused_attn( fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step - out_per_step, aux_ctx_tensors = fused_attn_fwd( + out_per_step, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q_, max_seqlen_kv_, @@ -713,6 +714,7 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_q_padded=cu_seqlens_q_padded_, cu_seqlens_kv_padded=cu_seqlens_kv_padded_, **fp8_meta_kwargs, + return_max_score=return_max_score, ) if fp8: @@ -721,7 +723,7 @@ def cp_p2p_fwd_fused_attn( softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None - return out_per_step, softmax_lse_per_step, rng_states, attn_bias + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, max_score def cp_p2p_fwd_flash_attn( @@ -1096,6 +1098,7 @@ def forward( use_flash_attn_3, fp8_output, layer_number, + return_max_score, ): # pylint: disable=missing-function-docstring @@ -1156,6 +1159,8 @@ def forward( amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] + max_score_per_step = [None for _ in range(cp_size)] + max_score = None assert isinstance(k, q.__class__) and isinstance( v, q.__class__ @@ -1244,6 +1249,10 @@ def forward( q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if return_max_score: + max_score_per_step = torch.empty( + (cp_size, q.shape[-2]), dtype=q.dtype, device=q.device + ) # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( @@ -1418,6 +1427,7 @@ def forward( rank, i, cp_size, + return_max_score, ] else: flash_attn_inputs = [ @@ -1462,6 +1472,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1488,6 +1499,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1514,6 +1526,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1541,6 +1554,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_score_per_step[i], ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1600,11 +1614,20 @@ def forward( softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), softmax_lse_per_step[i - 1], ) + if return_max_score: + if i == 1: + max_score = torch.clone(max_score_per_step[0]) + else: + max_score = torch.maximum(max_score, max_score_per_step[i - 1]) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) 
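For reference, the max_score added to UnfusedDotProductAttention above is just a per-head reduction of the masked, scaled pre-softmax scores. A minimal, standalone sketch of that reduction (names are illustrative; bias and dropout are omitted for brevity):

    import torch

    def reference_max_score(q, k, attn_scale, attention_mask=None):
        # q: [b, h, sq, d], k: [b, h, skv, d]; returns per-head maxima of the
        # masked, scaled scores S = attn_scale * Q @ K^T, i.e. a tensor of shape [h]
        scores = attn_scale * torch.matmul(q, k.transpose(-2, -1))  # [b, h, sq, skv]
        if attention_mask is not None:
            # convention: True marks positions to be masked out
            scores = scores.masked_fill(attention_mask, float("-inf"))
        return torch.amax(scores, dim=(0, 2, 3))

    b, h, sq, skv, d = 2, 16, 128, 128, 64
    q, k = torch.randn(b, h, sq, d), torch.randn(b, h, skv, d)
    print(reference_max_score(q, k, attn_scale=d ** -0.5).shape)  # torch.Size([16])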
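Under context parallelism each rank only sees part of the KV sequence, so the per-step maxima have to be combined: an elementwise maximum over the local steps followed by a MAX all-reduce over the CP group, which is what the surrounding hunks do. A minimal sketch of that combine, assuming torch.distributed is already initialized (the helper name is illustrative):

    import torch
    import torch.distributed as dist

    def combine_max_score(max_score_per_step, cp_group):
        # max_score_per_step: list of [h] tensors, one per attention step on this rank
        max_score = max_score_per_step[0].clone()
        for step_score in max_score_per_step[1:]:
            max_score = torch.maximum(max_score, step_score)
        # max is associative, so a MAX all-reduce across CP ranks yields the global value
        dist.all_reduce(max_score, op=dist.ReduceOp.MAX, group=cp_group)
        return max_score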
torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + if return_max_score: + torch.distributed.all_reduce( + max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) second_half_lse_seqlen = None if causal and rank < (cp_size - 1): @@ -1682,6 +1705,10 @@ def forward( elif qkv_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + if return_max_score: + max_score = flash_attn_a2a_communicate_softmax_offset( + max_score, 0, cp_size_a2a, cp_group_a2a, cp_stream, False + ) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) @@ -1811,10 +1838,10 @@ def forward( nvtx_range_pop(f"{nvtx_label}") - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring # add NVTX range @@ -2522,6 +2549,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -2581,6 +2609,7 @@ def forward( cp_group, cp_stream, use_flash_attn_3, + return_max_score, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") @@ -2682,6 +2711,8 @@ def forward( softmax_lse_per_step = [None, None] rng_states = [None, None] out = torch.empty_like(q) + max_score_per_step = [None, None] + max_score = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -2712,7 +2743,11 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + ( + out_per_step[i], + [softmax_lse_per_step[i], rng_states[i]], + max_score_per_step[i], + ) = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv_, @@ -2732,6 +2767,7 @@ def forward( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], window_size=window_size_per_step[i], + return_max_score=return_max_score, ) else: fa_forward_args_thd = get_fa_args( @@ -2767,14 +2803,22 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + if return_max_score and i == 0: + max_score = torch.clone(max_score_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + if return_max_score: + max_score = torch.maximum(max_score, max_score_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + if return_max_score: + torch.distributed.all_reduce( + max_score, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) if use_fused_attention: if qkv_format == "bshd": @@ -2811,10 +2855,10 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") - return out + return out, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3035,6 +3079,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3075,6 +3120,7 @@ def forward( softmax_type, softmax_offset, fp8_output, + return_max_score, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3158,6 +3204,7 @@ def 
forward( fp8_recipe = fp8_meta["local_recipes"][0] fwd_nominal_dtype = q.dtype fused_attn_backend = None + max_score = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) @@ -3203,7 +3250,7 @@ def forward( Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, max_score = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3226,6 +3273,7 @@ def forward( **fp8_meta_kwargs, softmax_type=softmax_type, softmax_offset=softmax_offset, + return_max_score=return_max_score, ) if isinstance(out_, Float8Tensor): out_fp8 = out_ @@ -3276,6 +3324,10 @@ def forward( out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) + if return_max_score: + max_score = flash_attn_a2a_communicate_softmax_offset( + max_score, 0, cp_size, cp_group, cp_stream, False + ) if use_fused_attention: if qkv_format == "bshd": @@ -3362,10 +3414,10 @@ def forward( ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") - return out_ret + return out_ret, max_score @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3601,6 +3653,7 @@ def backward(ctx, dout): None, d_softmax_offset, None, + None, ) @@ -3637,6 +3690,7 @@ def attn_forward_func_with_cp( softmax_offset=None, fp8_output=False, layer_number=1, + return_max_score=False, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3798,12 +3852,13 @@ def attn_forward_func_with_cp( use_flash_attn_3, fp8_output, layer_number, + return_max_score, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream, use_flash_attn_3] + args += [window_size, cp_group, cp_stream, use_flash_attn_3, return_max_score] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": args += [ @@ -3817,6 +3872,7 @@ def attn_forward_func_with_cp( softmax_type, softmax_offset, fp8_output, + return_max_score, ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index df96067d65..c980e6bfcf 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -255,6 +255,11 @@ class DotProductAttention(TransformerEngineBaseModule): where alpha is a learnable parameter in shape [h]. 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink'). + return_max_score: Optional[bool], default = `False` + If true, returns the maximum attention score, max_score = max(S), where + S = Q*K^T and in shape [b, h, s_q, s_kv]. max_score can be used to rescale + the Q and K projection weights in a MuonClip optimizer (see + `Muon is Scalable for LLM Training `_). 
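A short usage sketch of the new flag on DotProductAttention; sizes are made up, and it assumes the module surfaces the (output, max_score) pair the same way the backends above do:

    import torch
    from transformer_engine.pytorch import DotProductAttention

    b, s, h, d = 2, 512, 16, 64  # made-up sizes
    dpa = DotProductAttention(
        num_attention_heads=h,
        kv_channels=d,
        attn_mask_type="causal",
        return_max_score=True,
    ).cuda()

    # default qkv_format is "sbhd": tensors of shape [s, b, h, d]
    q, k, v = (torch.randn(s, b, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))
    out, max_score = dpa(q, k, v)  # max_score: per-head maxima, shape [h]
    # max_score can then drive a MuonClip-style rescaling of the Q/K projection weights.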
Parallelism parameters ---------------------- @@ -311,6 +316,7 @@ def __init__( cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", + return_max_score: Optional[bool] = False, ) -> None: super().__init__() @@ -394,6 +400,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.return_max_score = return_max_score self.softmax_type = softmax_type if self.softmax_type == "vanilla": @@ -431,6 +438,7 @@ def __init__( deterministic=self.deterministic, **attn_kwargs, softmax_type=self.softmax_type, + return_max_score=self.return_max_score, ) self.unfused_attention = UnfusedDotProductAttention( @@ -439,6 +447,7 @@ def __init__( **attn_kwargs, layer_number=layer_number, softmax_type=self.softmax_type, + return_max_score=self.return_max_score, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8b26a1760d..75aa855590 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -229,6 +229,8 @@ class AttentionParams: Inference-related parameters. See InferenceParams for details. softmax_type: str, default = "vanilla" The type of softmax operation. See DotProductAttention for details. + return_max_score: bool, default = `False` + Whether to output max_score. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -257,6 +259,7 @@ class AttentionParams: fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" + return_max_score: bool = False def __eq__(self, other): """ @@ -330,6 +333,7 @@ def get_attention_backend( fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type + return_max_score = attention_params.return_max_score # Run config logger = logging.getLogger("DotProductAttention") @@ -477,6 +481,24 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False + # Filter: Return max_score + if return_max_score: + if context_parallel: + use_flash_attention = False + use_unfused_attention = False + logger.debug( + "Disabling FlashAttention and UnfusedAttention for max_score with context" + " parallelism" + ) + if use_flash_attention: + use_flash_attention = False + logger.debug("Disabling FlashAttention for max_score") + if fp8 and fp8_meta["recipe"].fp8_dpa: + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + logger.debug("Disabling all backends for max_score with context parallelism in FP8") + # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- @@ -913,6 +935,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt head_dim_v, window_size[0], window_size[1], + return_max_score, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 94a12c4a09..1900b56789 100644 --- 
a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -139,6 +139,7 @@ def fused_attn_fwd( window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, + return_max_score: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -216,6 +217,8 @@ def fused_attn_fwd( softmax_offset: torch.Tensor, default = None softmax offset tensor in shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. + return_max_score: bool, default = False + whether to return the maximum attention score Returns ---------- @@ -246,6 +249,7 @@ def fused_attn_fwd( rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 + max_score: float if return_max_score = True, otherwise None """ if attn_scale is None: @@ -315,10 +319,20 @@ def fused_attn_fwd( softmax_offset, rng_gen, rng_elts_per_thread, + return_max_score, ) + if return_max_score: + # output_tensors: out [b, sq, h, d] or [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + stats = output_tensors[1] + torch.log(output_tensors[2]) + # Max [b, h, sq, 1] -> max_score [h] + max_score = torch.amax(output_tensors[1], dim=(0, 2, 3)).to(dtype=output_tensors[0].dtype) + aux_ctx_tensors = [stats] + aux_ctx_tensors.extend(output_tensors[3:]) + return output_tensors[0], aux_ctx_tensors, max_score + # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] + return output_tensors[0], output_tensors[1:], None def fused_attn_bwd( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d86a96959c..9d0f54a3d0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool return_max_score); std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, @@ -94,7 +94,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread); + size_t rng_elts_per_thread, bool return_max_score); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 344bc4ab0b..b0535676bf 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,11 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool return_max_score) { 
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_score); return fused_attention_backend; } @@ -106,7 +107,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread) { + size_t rng_elts_per_thread, bool return_max_score) { auto none = py::none(); // create QKV tensor wrappers @@ -228,8 +229,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -249,7 +251,9 @@ std::vector fused_attn_fwd( }; // allocate memory for nvte_aux_tensor_pack.tensors // f16_max512 : S [b, h, sq, skv] - // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // f16_arbitrary: + // return_max_score=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // return_max_score=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv] // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] size_t i = 0; at::Tensor output_tensor; @@ -259,7 +263,7 @@ std::vector fused_attn_fwd( static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); // fp8 has an additional softmax stats tensor, ZInv - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (return_max_score || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -285,8 +289,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_score, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory
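The Python wrapper change in cpp_extensions/fused_attn.py relies on the identity logsumexp = Max + log(Sum_Exp) to rebuild the Stats tensor the backward pass expects from the two tensors the kernel emits when return_max_score is set, and reduces Max to a per-head max_score. A toy numerical check of that reconstruction (shapes are arbitrary):

    import torch

    b, h, sq, skv = 2, 4, 8, 8
    logits = torch.randn(b, h, sq, skv)

    row_max = logits.amax(dim=-1, keepdim=True)                  # "Max"     [b, h, sq, 1]
    sum_exp = torch.exp(logits - row_max).sum(-1, keepdim=True)  # "Sum_Exp" [b, h, sq, 1]

    # Stats (logsumexp) recovered from Max and Sum_Exp, as in the wrapper
    stats = row_max + torch.log(sum_exp)
    torch.testing.assert_close(stats, torch.logsumexp(logits, dim=-1, keepdim=True))

    # per-head max_score, matching the amax(dim=(0, 2, 3)) reduction in the wrapper
    max_score = row_max.amax(dim=(0, 2, 3))
    print(max_score.shape)  # torch.Size([4])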
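On the C++ side, generate_stats and generate_max_sum_exp also join the FADescriptor_v1 ordering that serves as the cuDNN graph-cache key; otherwise a graph built to emit Stats could be silently reused for a call that needs Max and Sum_Exp. A toy Python illustration of that cache-key point (types and names here are illustrative only, not the library's API):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ToyGraphKey:
        # trimmed-down stand-in for FADescriptor_v1, keeping only the fields used here
        b: int
        h: int
        s_q: int
        s_kv: int
        generate_stats: bool
        generate_max_sum_exp: bool

    graph_cache = {}

    def get_graph(key):
        # build (or reuse) a graph per distinct key; without the two booleans in the
        # key, both calls below would collide on the same cache entry
        if key not in graph_cache:
            outputs = "Max,Sum_Exp" if key.generate_max_sum_exp else "Stats"
            graph_cache[key] = f"graph(outputs={outputs})"
        return graph_cache[key]

    stats_key = ToyGraphKey(2, 16, 128, 128, generate_stats=True, generate_max_sum_exp=False)
    max_key = ToyGraphKey(2, 16, 128, 128, generate_stats=False, generate_max_sum_exp=True)
    assert get_graph(stats_key) != get_graph(max_key)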