From 19c8ccc7d1e5101d5093658a89653bd6ffe4cdbe Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Wed, 9 Oct 2024 07:14:47 -0700 Subject: [PATCH] PR #15331: Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) Imported from GitHub PR https://github.com/openxla/xla/pull/15331 As the 2nd part of #15092. NOTE: this feature relies on cudnn-frontend v1.6.1 which is not in XLA yet. Copybara import of the project: -- 06db3c8349ca017440a2b9c4f4a7c41e557f03af by shuw : Scaled dot product attention implementation by cudnn. -- 937b0e26ebcf5d48fce15fed8573d7c58b47e689 by shuw : Improve after review 1 -- 398b2ba2cef82f701a0ddecb7553423d92b1f902 by shuw : clang-format -- 08257899ea899f66799bc701d81aad6ea94af6a0 by Shu Wang : fix typo. -- d0ae3cf52b7483c254137d8300f4c00aa963a7c6 by shuw : Refactor test Merging this change closes #15331 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/15331 from wenscarl:sdpa_fp8_bwd d0ae3cf52b7483c254137d8300f4c00aa963a7c6 PiperOrigin-RevId: 684025541 --- xla/service/gpu/cublas_cudnn.cc | 14 +- xla/service/gpu/cublas_cudnn.h | 2 + xla/service/gpu/tests/gpu_fused_mha_test.cc | 258 ++++++++++++++++++ .../transforms/cudnn_custom_call_compiler.cc | 71 ++++- xla/stream_executor/cuda/cuda_dnn.cc | 196 +++++++++++++ xla/stream_executor/cuda/cuda_dnn.h | 10 + 6 files changed, 544 insertions(+), 7 deletions(-) diff --git a/xla/service/gpu/cublas_cudnn.cc b/xla/service/gpu/cublas_cudnn.cc index 18e131eee8f108..a9d94e8ed8ae33 100644 --- a/xla/service/gpu/cublas_cudnn.cc +++ b/xla/service/gpu/cublas_cudnn.cc @@ -128,11 +128,13 @@ bool IsCustomCallToDnnNorm(const HloInstruction& hlo) { } bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) { - if (hlo.opcode() != HloOpcode::kCustomCall) { - return false; - } - const auto& target = hlo.custom_call_target(); - return target == kCudnnfMHASoftmaxF8CallTarget; + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kCudnnfMHASoftmaxF8CallTarget; +} + +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kCudnnfMHASoftmaxBackwardF8CallTarget; } bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) { @@ -169,7 +171,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) { } bool IsCustomCallTofMHAF8(const HloInstruction& hlo) { - return IsFwdCustomCallTofMHAF8(hlo); + return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo); } bool IsCubDeviceRadixSort(const HloInstruction& hlo) { diff --git a/xla/service/gpu/cublas_cudnn.h b/xla/service/gpu/cublas_cudnn.h index 9befcbb60901b1..d3f0a1ce22fdb3 100644 --- a/xla/service/gpu/cublas_cudnn.h +++ b/xla/service/gpu/cublas_cudnn.h @@ -188,6 +188,7 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget; // Backward calls +extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget; extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget; extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget; extern const absl::string_view @@ -195,6 +196,7 @@ extern const absl::string_view extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget; bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo); +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo); bool IsCustomCallTofMHAF8(const HloInstruction& hlo); bool 
IsFwdCustomCallTofMHA(const HloInstruction& hlo); bool IsBwdCustomCallTofMHA(const HloInstruction& hlo); diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 666cc1a041e8eb..fd9b9208311886 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1813,6 +1813,264 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxDropoutBMM, Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2) { TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2(); } + +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + XlaBuilder builder(TestName()); + std::string hlo_string_ref = R"( + HloModule fmha_cudnn_custom_call_bwd + // Process inputs: clip, convert to f8e4m3fn, and convert back to bf16 + cast_to_representable { + // Parameters + input = bf16[1,1,128,128] parameter(0) + min_val = bf16[] parameter(1) + max_val = bf16[] parameter(2) + + // Broadcasting min and max values + min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={} + max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={} + + // Clipping the scaled input + clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input) + clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min) + + // Converting to f8e4m3fn and back to bf16 + converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped) + ROOT converted_bf16 = bf16[1,1,128,128] convert(converted_f8) + } + // Main function + ENTRY main { + // Input parameters + query = bf16[1,1,128,128] parameter(0) + key = bf16[1,1,128,128] parameter(1) + value = bf16[1,1,128,128] parameter(2) + grad_output = bf16[1,1,128,128] parameter(3) + fwd_output = bf16[1,1,128,128] parameter(4) + score = f32[1,1,128] parameter(5) + + // Constants + one_f32 = f32[] constant(1) + one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={} + min_clip_val = bf16[] constant(-448) + max_clip_val = bf16[] constant(448) + + query_processed = bf16[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable + key_processed = bf16[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable + value_processed = bf16[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable + grad_output_processed = bf16[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + fwd_output_processed = bf16[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + + // FMHA Forward Backward custom call + fmha_result = (bf16[1,1,128,128], bf16[1,1,128,128], bf16[1,1,128,128], u8[0]) custom-call( + query_processed, key_processed, value_processed, + score, fwd_output_processed, grad_output_processed + ), + custom_call_target="__cudnn$fmhaSoftmaxBackward", + operand_layout_constraints={ + bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0}, + bf16[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0}, + bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": {"17": "1", "24": "0"}, + 
"is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 1.0, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["1", "1", "128", "128"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm1_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + + ROOT output = bf16[1,1,128,128] get-tuple-element(fmha_result), index=0 + })"; + + std::string hlo_string = R"( + HloModule fmha_cudnn_custom_call_bwd_f8 + // Process inputs: clip, convert to f8e4m3fn + cast_to_representable { + // Parameters + input = bf16[1,1,128,128] parameter(0) + min_val = bf16[] parameter(1) + max_val = bf16[] parameter(2) + + // Broadcasting min and max values + min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={} + max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={} + + // Clipping the scaled input + clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input) + clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min) + + // Converting to f8e4m3fn and back to bf16 + ROOT converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped) + } + + // Main function + ENTRY main { + // Input parameters + query = bf16[1,1,128,128] parameter(0) + key = bf16[1,1,128,128] parameter(1) + value = bf16[1,1,128,128] parameter(2) + grad_output = bf16[1,1,128,128] parameter(3) + fwd_output = bf16[1,1,128,128] parameter(4) + score = f32[1,1,128] parameter(5) + + // Constants + one_f32 = f32[] constant(1) + one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={} + min_clip_val = bf16[] constant(-448) + max_clip_val = bf16[] constant(448) + + query_processed = f8e4m3fn[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable + key_processed = f8e4m3fn[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable + value_processed = f8e4m3fn[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable + grad_output_processed = f8e4m3fn[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + fwd_output_processed = f8e4m3fn[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + + // FMHA Softmax Backward custom call + fmha_result = (f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128], 
f8e4m3fn[1,1,128,128], + f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], u8[0]) custom-call( + query_processed, key_processed, value_processed, + grad_output_processed, fwd_output_processed, score, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast + ), + custom_call_target="__cudnn$fmhaSoftmaxBackwardF8", + operand_layout_constraints={ + f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0}, + f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0}, + f8e4m3fn[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": {"17": "1", "24": "0"}, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 1.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["1", "1", "128", "128"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm1_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + + fmha_output = f8e4m3fn[1,1,128,128] get-tuple-element(fmha_result), index=0 + ROOT output = bf16[1,1,128,128] convert(fmha_output) + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string_ref, hlo_string, + ErrorSpec{2e-1, 2e-1})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 1c01b3f47cd878..9f7668c9e226bd 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -199,7 +199,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, activation, static_cast(config.fmha_scale()), dnn_mask_type)); return std::move(graph); - } 
else { + } else if (IsBwdCustomCallTofMHA(*custom_call)) { TF_ASSIGN_OR_RETURN( auto gpu_config, custom_call->backend_config()); @@ -314,6 +314,75 @@ absl::StatusOr HloCustomCallToCuDnnGraph( config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, dnn_mask_type, force_deterministic, sliding_window_length)); return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + Shape bmm1_grad_gemm1_rhs_shape = custom_call->operand(0)->shape(); + Shape bmm1_grad_gemm2_rhs_shape = custom_call->operand(1)->shape(); + Shape bmm2_grad_gemm2_rhs_shape = custom_call->operand(2)->shape(); + + Shape fwd_output_shape = custom_call->operand(3)->shape(); + Shape d_output_shape = custom_call->operand(4)->shape(); + + Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + + Shape d_bmm1_lhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {0}); + Shape d_bmm1_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {1}); + Shape d_bmm2_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {2}); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm1_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm1_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm2_rhs_shape, + config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm1_lhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm1_lhs_shape, + config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm2_rhs_shape, + config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor d_output, + MatmulTensorDescriptorFor( + d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), + RHS)); + + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs, + TensorDescriptorFor(d_bmm1_lhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs, + TensorDescriptorFor(d_bmm1_rhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs, + TensorDescriptorFor(d_bmm2_rhs_shape)); + // 3 gradients, 4 amaxs and one workspace + TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size()); + + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, + bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, + d_bmm1_rhs, d_bmm2_rhs, config.fmha_scale(), dnn_mask_type)); + return std::move(graph); } } diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 93bb5e46b14f38..67d51a13ce5b29 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -5280,6 +5280,202 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( #endif } +absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, + const 
dnn::MatmulTensorDescriptor& k_desc, + const dnn::MatmulTensorDescriptor& p_desc, + const dnn::MatmulTensorDescriptor& v_desc, + const dnn::MatmulTensorDescriptor& do_desc, + const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc, + const dnn::TensorDescriptor& dv_desc, double scale, + dnn::FMHAMaskKind mask_type) { +#if CUDNN_VERSION >= 90100 + if (VLOG_IS_ON(4)) { + VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() + << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString() + << "\n bmm2_grad_gemm1_lhs(p): " << p_desc.ToString() + << "\n bmm2_grad_gemm2_rhs(v^t): " << v_desc.ToString() + << "\n d_output(do): " << do_desc.ToString() + << "\n d_bmm1_lhs(dq): " << dq_desc.ToString() + << "\n d_bmm1_rhs(dk): " << dk_desc.ToString() + << "\n d_bmm2_rhs(dv): " << dv_desc.ToString() + << "\n scale: " << scale; + } + using cudnn_frontend::graph::Tensor_attributes; + cudnn_frontend::graph::Graph graph; + if (!(q_desc.type() == k_desc.type() && v_desc.type() == do_desc.type() && + do_desc.type() == dq_desc.type() && dq_desc.type() == dk_desc.type() && + dk_desc.type() == dv_desc.type())) { + return absl::InternalError("Input datatypes do not match."); + } + + auto ioDataType = ToCudnnFrontendDataType(q_desc.type()); + graph.set_compute_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_io_data_type(ioDataType); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + + std::shared_ptr q = + graph.tensor(Tensor_attributes() + .set_name("Q") + .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(q_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr k = + graph.tensor(Tensor_attributes() + .set_name("K") + .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(k_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr v = + graph.tensor(Tensor_attributes() + .set_name("V") + .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) + .set_stride(v_desc.GetCudnnCompatibleStrides(true)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr dO = + graph.tensor(Tensor_attributes() + .set_name("dO") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. 
+ std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + std::shared_ptr Stats = + graph.tensor(Tensor_attributes() + .set_name("Stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); + + auto descale_q = + graph.tensor(Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); + auto descale_k = graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = graph.tensor_like(descale_q, "Scale_dV"); + + descale_k->set_uid(next_uid()); + descale_v->set_uid(next_uid()); + descale_s->set_uid(next_uid()); + descale_o->set_uid(next_uid()); + descale_dO->set_uid(next_uid()); + descale_dP->set_uid(next_uid()); + + scale_s->set_uid(next_uid()); + scale_dP->set_uid(next_uid()); + scale_dQ->set_uid(next_uid()); + scale_dK->set_uid(next_uid()); + scale_dV->set_uid(next_uid()); + + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL; + auto sdpa_fp8_backwards_options = + cudnn_frontend::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale); + + auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] = + graph.sdpa_fp8_backward(q, k, v, o, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, + scale_dP, sdpa_fp8_backwards_options); + + dQ->set_output(true) + .set_dim(dq_desc.dimensions()) + .set_stride(dq_desc.GetLogicalStrides()) + .set_name("dQ") + .set_uid(next_uid()) + .set_data_type(ioDataType); + dK->set_output(true) + .set_dim(dk_desc.dimensions()) + .set_stride(dk_desc.GetLogicalStrides()) + .set_name("dK") + .set_uid(next_uid()) + .set_data_type(ioDataType); + dV->set_output(true) + .set_dim(dv_desc.dimensions()) + .set_stride(dv_desc.GetLogicalStrides()) + .set_name("dV") + .set_uid(next_uid()) + .set_data_type(ioDataType); + Amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + + CudnnGraph cudnnGraph(std::move(graph)); + TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + dnn_support, NumericOptions{/*require_determinism=*/false, + 
/*allow_tf32=*/true})); + TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + + if (VLOG_IS_ON(4)) { + VLOG(4) << "\b workspace size:" << cudnnGraph.Graph().get_workspace_size(); + VLOG(4) << "\b flash attention f8 operation backward graph: " << graph; + } + + return cudnnGraph; +#else + return absl::UnimplementedError( + "Cudnn flash attention only supported with Cudnn >= 9.1.0"); +#endif +} + absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, const dnn::MatmulTensorDescriptor& k_desc, diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h index 3eb702b9f4415e..83df67b4cdcc53 100644 --- a/xla/stream_executor/cuda/cuda_dnn.h +++ b/xla/stream_executor/cuda/cuda_dnn.h @@ -732,6 +732,16 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const dnn::FMHAMaskKind mask_type, bool force_deterministic, const int sliding_window_length); +absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, + const dnn::MatmulTensorDescriptor& k_desc, + const dnn::MatmulTensorDescriptor& p_desc, + const dnn::MatmulTensorDescriptor& v_desc, + const dnn::MatmulTensorDescriptor& do_desc, + const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc, + const dnn::TensorDescriptor& dv_desc, double scale, + dnn::FMHAMaskKind mask_type); + } // namespace gpu } // namespace stream_executor
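
Note (illustrative only, not part of the patch): below is a minimal sketch of how the new stream-executor entry point declared in cuda_dnn.h can be invoked. It assumes the q/k/p/v/d_output MatmulTensorDescriptors and the dq/dk/dv TensorDescriptors have already been derived from the custom call (for example via the MatmulTensorDescriptorFor/TensorDescriptorFor helpers used in cudnn_custom_call_compiler.cc above); the helper name BuildFp8BackwardGraph and the use of dnn::FMHAMaskKind::NO_MASK (chosen to match the "NO_MASK" backend config in the new F8 backward test) are assumptions for illustration, not code from this change.

#include "absl/status/statusor.h"
#include "xla/stream_executor/cuda/cuda_dnn.h"
#include "xla/stream_executor/dnn.h"

namespace se = stream_executor;

// Builds (but does not execute) the cuDNN FP8 flash-attention backward graph
// for the given operand descriptors, mirroring the call made in
// cudnn_custom_call_compiler.cc for the __cudnn$fmhaSoftmaxBackwardF8 target.
absl::StatusOr<se::gpu::CudnnGraph> BuildFp8BackwardGraph(
    se::dnn::DnnSupport& dnn_support,
    const se::dnn::MatmulTensorDescriptor& q,
    const se::dnn::MatmulTensorDescriptor& k,
    const se::dnn::MatmulTensorDescriptor& p,
    const se::dnn::MatmulTensorDescriptor& v,
    const se::dnn::MatmulTensorDescriptor& d_output,
    const se::dnn::TensorDescriptor& dq,
    const se::dnn::TensorDescriptor& dk,
    const se::dnn::TensorDescriptor& dv) {
  // NO_MASK matches the mask type exercised by the new F8 backward test.
  return se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(
      dnn_support, q, k, p, v, d_output, dq, dk, dv,
      /*scale=*/1.0, se::dnn::FMHAMaskKind::NO_MASK);
}

The returned CudnnGraph is already prepared and built (Prepare and Build are called inside GetCudnnFlashAttentionBackwardF8OperationGraph), so a caller only needs to bind device buffers to the sequentially assigned tensor UIDs at execution time.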