From 949191a4060f6f0bf534f57733202c283b845447 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 15 Aug 2024 08:37:53 -0700 Subject: [PATCH] bwd --- xla/service/gpu/cublas_cudnn.cc | 10 +- xla/service/gpu/cublas_cudnn.h | 2 + xla/service/gpu/tests/gpu_fused_mha_test.cc | 122 +++++++++++ .../transforms/cudnn_custom_call_compiler.cc | 120 ++++++++++- xla/stream_executor/cuda/cuda_dnn.cc | 198 ++++++++++++++++++ xla/stream_executor/cuda/cuda_dnn.h | 22 +- 6 files changed, 465 insertions(+), 9 deletions(-) diff --git a/xla/service/gpu/cublas_cudnn.cc b/xla/service/gpu/cublas_cudnn.cc index d25eac0c95cb70..04754f24025d4b 100644 --- a/xla/service/gpu/cublas_cudnn.cc +++ b/xla/service/gpu/cublas_cudnn.cc @@ -135,6 +135,14 @@ bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) { return target == kCudnnfMHASoftmaxF8CallTarget; } +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCudnnfMHASoftmaxBackwardF8CallTarget; +} + bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { return false; @@ -169,7 +177,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) { } bool IsCustomCallTofMHAF8(const HloInstruction& hlo) { - return IsFwdCustomCallTofMHAF8(hlo); + return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo); } bool IsCubDeviceRadixSort(const HloInstruction& hlo) { diff --git a/xla/service/gpu/cublas_cudnn.h b/xla/service/gpu/cublas_cudnn.h index 24b3d4f61f82b0..a6f8e340daf656 100644 --- a/xla/service/gpu/cublas_cudnn.h +++ b/xla/service/gpu/cublas_cudnn.h @@ -188,6 +188,7 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget; // Backward calls +extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget; extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget; extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget; extern const absl::string_view @@ -195,6 +196,7 @@ extern const absl::string_view extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget; bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo); +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo); bool IsCustomCallTofMHAF8(const HloInstruction& hlo); bool IsFwdCustomCallTofMHA(const HloInstruction& hlo); bool IsBwdCustomCallTofMHA(const HloInstruction& hlo); diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index fa60f0d55ec765..1bc47dabde6b18 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1189,6 +1189,123 @@ class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8 EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, ErrorSpec{1e-2, 1e-2})); } + + void TestImpl_Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; + } + XlaBuilder builder(TestName()); + // generate padding mask in cuDNN directly + // XLA pattern match does not support pattern matching padding mask + // so directly lower to custom call instead for reference + std::string hlo_string_ref = R"( + HloModule 
jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true} + + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[1,1,256,128]{3,2,1,0} parameter(0) + maximum.38 = bf16[1,1,256,128]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[1,1,256,128]{3,2,1,0} minimum(broadcast.39, maximum.38) + } // clip.33 + + ENTRY main.106 { + constant.99 = f32[] constant(1) + broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} + Arg_0.1 = bf16[1,1,256,128]{3,2,1,0} parameter(0) + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[1,1,256,128]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.8 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_0.1, broadcast.7) + constant.5 = bf16[] constant(-448) + constant.4 = bf16[] constant(448) + call.17 = bf16[1,1,256,128]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 + convert.18 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.17) + convert.19 = bf16[1,1,256,128]{3,2,1,0} convert(convert.18) + Arg_1.2 = bf16[1,1,256,128]{3,2,1,0} parameter(1) + divide.20 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[1,1,256,128]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 + convert.30 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.29) + convert.31 = bf16[1,1,256,128]{3,2,1,0} convert(convert.30) + Arg_2.3 = bf16[1,1,256,128]{3,2,1,0} parameter(2) + divide.32 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[1,1,256,128]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 + convert.42 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.41) + convert.43 = bf16[1,1,256,128]{3,2,1,0} convert(convert.42) + Arg_3.4 = bf16[1,1,256,128]{3,2,1,0} parameter(3) + divide.72 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_3.4, broadcast.7) + call.71 = bf16[1,1,256,128]{3,2,1,0} call(divide.72, constant.5, constant.4), to_apply=clip.33 + convert.72 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.71) + convert.73 = bf16[1,1,256,128]{3,2,1,0} convert(convert.72) + Arg_4.5 = bf16[1,1,256,128]{3,2,1,0} parameter(4) + divide.82 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_4.5, broadcast.7) + call.81 = bf16[1,1,256,128]{3,2,1,0} call(divide.82, constant.5, constant.4), to_apply=clip.33 + convert.82 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.81) + convert.83 = bf16[1,1,256,128]{3,2,1,0} convert(convert.82) + bt.0 = f32[1,1,256]{2,1,0} broadcast(constant.99), dimensions={} + custom-call.7 = (bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}, u8[0]{0}) custom-call(convert.19, convert.31, convert.43, bt.0, convert.83, /*index=5*/convert.73), custom_call_target="__cudnn$fmhaSoftmaxBackward", operand_layout_constraints={bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, f32[1,1,256]{2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": 
"TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "16", "16"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}}} + get-tuple-element.8 = bf16[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.7), index=0 + get-tuple-element.9 = bf16[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.7), index=1 + ROOT tuple.3 = (bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}) tuple(get-tuple-element.8, get-tuple-element.9) + })"; + + std::string hlo_string = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true} + + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[1,1,256,128]{3,2,1,0} parameter(0) + maximum.38 = bf16[1,1,256,128]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[1,1,256,128]{3,2,1,0} minimum(broadcast.39, maximum.38) + } + + ENTRY main.106 { + constant.99 = f32[] constant(1) + broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} + Arg_0.1 = bf16[1,1,256,128]{3,2,1,0} parameter(0) + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[1,1,256,128]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.8 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_0.1, broadcast.7) + constant.5 = bf16[] constant(-448) + constant.4 = bf16[] constant(448) + call.17 = bf16[1,1,256,128]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 + convert.18 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.17) + convert.19 = bf16[1,1,256,128]{3,2,1,0} convert(convert.18) + Arg_1.2 = bf16[1,1,256,128]{3,2,1,0} parameter(1) + divide.20 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[1,1,256,128]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 + convert.30 = f8e4m3fn[1,1,256,128]{3,2,1,0} 
convert(call.29) + convert.31 = bf16[1,1,256,128]{3,2,1,0} convert(convert.30) + Arg_2.3 = bf16[1,1,256,128]{3,2,1,0} parameter(2) + divide.32 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[1,1,256,128]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 + convert.42 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.41) + convert.43 = bf16[1,1,256,128]{3,2,1,0} convert(convert.42) + Arg_3.4 = bf16[1,1,256,128]{3,2,1,0} parameter(3) + divide.72 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_3.4, broadcast.7) + call.71 = bf16[1,1,256,128]{3,2,1,0} call(divide.72, constant.5, constant.4), to_apply=clip.33 + convert.72 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.71) + convert.73 = bf16[1,1,256,128]{3,2,1,0} convert(convert.72) + Arg_4.5 = bf16[1,1,256,128]{3,2,1,0} parameter(4) + divide.82 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_4.5, broadcast.7) + call.81 = bf16[1,1,256,128]{3,2,1,0} call(divide.82, constant.5, constant.4), to_apply=clip.33 + convert.82 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.81) + convert.83 = bf16[1,1,256,128]{3,2,1,0} convert(convert.82) + bt.0 = f32[1,1,256]{2,1,0} broadcast(constant.99), dimensions={} + custom-call.9 = (f8e4m3fn[1,1,256,128]{3,1,2,0}, f8e4m3fn[1,1,256,128]{3,1,2,0}, f8e4m3fn[1,1,256,128]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, u8[0]{0}) custom-call(convert.18, convert.30, convert.42, convert.72, convert.82, bt.0, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99), custom_call_target="__cudnn$fmhaSoftmaxBackward$f8" + get-tuple-element.10 = f8e4m3fn[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.9), index=0 + ROOT out.10 = bf16[1,1,256,128]{3,1,2,0} convert(get-tuple-element.10) + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{1e-1, 1e-1})); + } }; // BMM1 - Scale - CausalMask - Softmax - BMM2 @@ -1245,6 +1362,11 @@ XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8, Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) { TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8(); } + +XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8, + Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) { + TestImpl_Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8(); +} } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index be826d7a2e0a50..127998157431b8 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -159,7 +159,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( static_cast(config.fmha_scale()), dropout_rate > 0.0, dropout_rate, dnn_mask_type)); return std::move(graph); - } else if (IsFwdCustomCallTofMHAF8(*custom_call)) { + } else if (IsFwdCustomCallTofMHAF8(*custom_call)) { TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, xla::gpu::GetCudnnfMHAKind(custom_call)); TF_ASSIGN_OR_RETURN( @@ -205,7 +205,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( rhs_bmm2, output, activation, static_cast(config.fmha_scale()), dnn_mask_type)); return std::move(graph); - } else { + } else if (IsBwdCustomCallTofMHA(*custom_call)) { TF_ASSIGN_OR_RETURN( auto gpu_config, custom_call->backend_config()); @@ -322,6 +322,122 @@ absl::StatusOr HloCustomCallToCuDnnGraph( config.fmha_scale(), dropout_rate > 0.0, bias != 
std::nullopt, dnn_mask_type, force_deterministic));
     return std::move(graph);
+  } else {
+    TF_ASSIGN_OR_RETURN(
+        auto gpu_config,
+        custom_call->backend_config<xla::gpu::GpuBackendConfig>());
+    xla::gpu::CudnnfMHABackendConfig& config =
+        *gpu_config.mutable_cudnn_fmha_backend_config();
+
+    int input_index = 0;
+    Shape bmm1_grad_gemm1_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+    Shape bmm1_grad_gemm2_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+    Shape bmm2_grad_gemm2_rhs_shape =
+        custom_call->operand(input_index++)->shape();
+
+    Shape fwd_output_shape = custom_call->operand(input_index++)->shape();
+    Shape d_output_shape = custom_call->operand(input_index++)->shape();
+
+    Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
+    input_index++;
+
+    TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
+                        GetCudnnfMHAKind(custom_call));
+    TF_RET_CHECK(input_index == 6);
+
+    int output_index = 0;
+    Shape d_bmm1_lhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    Shape d_bmm1_rhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+    Shape d_bmm2_rhs_shape =
+        ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
+
+    TF_ASSIGN_OR_RETURN(
+        MatmulTensorDescriptor bmm1_grad_gemm1_rhs,
+        MatmulTensorDescriptorFor(
+            bmm1_grad_gemm1_rhs_shape,
+            config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS));
+    TF_ASSIGN_OR_RETURN(
+        MatmulTensorDescriptor bmm1_grad_gemm2_rhs,
+        MatmulTensorDescriptorFor(
+            bmm1_grad_gemm2_rhs_shape,
+            config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS));
+    TF_ASSIGN_OR_RETURN(
+        MatmulTensorDescriptor bmm2_grad_gemm1_lhs,
+        MatmulTensorDescriptorFor(
+            bmm2_grad_gemm1_lhs_shape,
+            config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS));
+    TF_ASSIGN_OR_RETURN(
+        MatmulTensorDescriptor bmm2_grad_gemm2_rhs,
+        MatmulTensorDescriptorFor(
+            bmm2_grad_gemm2_rhs_shape,
+            config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS));
+    TF_ASSIGN_OR_RETURN(
+        MatmulTensorDescriptor d_output,
+        MatmulTensorDescriptorFor(
+            d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(),
+            RHS));
+
+    TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs,
+                        TensorDescriptorFor(d_bmm1_lhs_shape));
+    TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs,
+                        TensorDescriptorFor(d_bmm1_rhs_shape));
+    TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs,
+                        TensorDescriptorFor(d_bmm2_rhs_shape));
+
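+    // The result tuple is (dQ, dK, dV), followed by the four amax scalars
+    // (amax_dQ, amax_dK, amax_dV, amax_dP) and the scratch workspace, which
+    // is why the check below subtracts 5 from the tuple size.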
+    TF_RET_CHECK(output_index ==
+                 custom_call->shape().tuple_shapes().size() - 5);
+
+    TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));
+
+    TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
+                        AsCudnnFmhaMaskKind(config.mask_type()));
+    TF_ASSIGN_OR_RETURN(
+        se::dnn::FMHAMaskKind dnn_mask_type,
+        GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type));
+    TF_ASSIGN_OR_RETURN(
+        se::gpu::CudnnGraph graph,
+        se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(
+            dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs,
+            bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs,
+            d_bmm1_rhs, d_bmm2_rhs, config.fmha_scale(), dnn_mask_type));
+    return std::move(graph);
   }
 }
diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 6ad78c328ccbaa..d9e8de034137cf 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -5241,10 +5241,12 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionF8OperationGraph(
   amax_s->set_output(true)
       .set_dim({1, 1, 1, 1})
       .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_stride({1, 1, 1, 1})
       .set_uid(next_uid());
   amax_o->set_output(true)
       .set_dim({1, 1, 1, 1})
       .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_stride({1, 1, 1, 1})
       .set_uid(next_uid());
 
   if (stats_descriptor.has_value()) {
@@ -5278,6 +5280,202 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionF8OperationGraph(
 #endif
 }
 
+absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardF8OperationGraph(
+    dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc,
+    const dnn::MatmulTensorDescriptor& k_desc,
+    const dnn::MatmulTensorDescriptor& p_desc,
+    const dnn::MatmulTensorDescriptor& v_desc,
+    const dnn::MatmulTensorDescriptor& do_desc,
+    const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc,
+    const dnn::TensorDescriptor& dv_desc, double scale,
+    dnn::FMHAMaskKind mask_type) {
+#if CUDNN_VERSION >= 90000
+  if (VLOG_IS_ON(4)) {
+    VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString()
+            << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString()
+            << "\n bmm2_grad_gemm1_lhs(p): " << p_desc.ToString()
+            << "\n bmm2_grad_gemm2_rhs(v^t): " << v_desc.ToString()
+            << "\n d_output(do): " << do_desc.ToString()
+            << "\n d_bmm1_lhs(dq): " << dq_desc.ToString()
+            << "\n d_bmm1_rhs(dk): " << dk_desc.ToString()
+            << "\n d_bmm2_rhs(dv): " << dv_desc.ToString()
+            << "\n scale: " << scale;
+  }
+  using cudnn_frontend::graph::Tensor_attributes;
+  cudnn_frontend::graph::Graph graph;
+  if (!(q_desc.type() == k_desc.type() && v_desc.type() == do_desc.type() &&
+        do_desc.type() == dq_desc.type() && dq_desc.type() == dk_desc.type() &&
+        dk_desc.type() == dv_desc.type())) {
+    return absl::InternalError("Input datatypes do not match.");
+  }
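+
+  // Build the FP8 SDPA backward graph. Tensor UIDs are handed out
+  // sequentially through CuDnnTensorUID, and their order is expected to
+  // match the order in which XLA binds the custom call's operands and
+  // results.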
+  auto ioDataType = ToCudnnFrontendDataType(q_desc.type());
+  graph.set_compute_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_io_data_type(ioDataType);
+
+  auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); };
+
+  std::shared_ptr<Tensor_attributes> q =
+      graph.tensor(Tensor_attributes()
+                       .set_name("Q")
+                       .set_dim(q_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(q_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> k =
+      graph.tensor(Tensor_attributes()
+                       .set_name("K")
+                       .set_dim(k_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(k_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> v =
+      graph.tensor(Tensor_attributes()
+                       .set_name("V")
+                       .set_dim(v_desc.GetCudnnCompatibleDimensions(true))
+                       .set_stride(v_desc.GetCudnnCompatibleStrides(true))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> o =
+      graph.tensor(Tensor_attributes()
+                       .set_name("O")
+                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+  std::shared_ptr<Tensor_attributes> dO =
+      graph.tensor(Tensor_attributes()
+                       .set_name("dO")
+                       .set_dim(do_desc.GetCudnnCompatibleDimensions(false))
+                       .set_stride(do_desc.GetCudnnCompatibleStrides(false))
+                       .set_uid(next_uid())
+                       .set_data_type(ioDataType));
+
+  // Stats carries the per-row softmax statistics saved by the forward pass:
+  // the P dimensions with the last dimension reduced to 1.
+  auto p_dims = p_desc.GetCudnnCompatibleDimensions(false);
+  auto p_strides = p_desc.GetCudnnCompatibleStrides(false);
+  std::vector<int64_t> p_reduction_dims(p_dims.begin(), p_dims.end() - 1);
+  p_reduction_dims.push_back(1);
+
+  // Divide every stride by the last dim value.
+  std::vector<int64_t> p_reduction_strides;
+  p_reduction_strides.reserve(p_strides.size());
+  int64_t p_reduced_dim_len = p_dims.back();
+  for (auto stride : p_strides) {
+    p_reduction_strides.push_back(stride / p_reduced_dim_len);
+  }
+  p_reduction_strides[3] = 1;
+  std::shared_ptr<Tensor_attributes> Stats =
+      graph.tensor(Tensor_attributes()
+                       .set_name("Stats")
+                       .set_dim(p_reduction_dims)
+                       .set_stride(p_reduction_strides)
+                       .set_uid(next_uid())
+                       .set_data_type(cudnn_frontend::DataType_t::FLOAT));
+
+  auto descale_q =
+      graph.tensor(Tensor_attributes()
+                       .set_name("Descale_Q")
+                       .set_dim({1, 1, 1, 1})
+                       .set_stride({1, 1, 1, 1})
+                       .set_uid(next_uid())
+                       .set_data_type(cudnn_frontend::DataType_t::FLOAT));
+  auto descale_k = graph.tensor_like(descale_q, "Descale_K");
+  auto descale_v = graph.tensor_like(descale_q, "Descale_V");
+  auto descale_s = graph.tensor_like(descale_q, "Descale_S");
+  auto descale_o = graph.tensor_like(descale_q, "Descale_O");
+  auto descale_dO = graph.tensor_like(descale_q, "Descale_dO");
+  auto descale_dP = graph.tensor_like(descale_q, "Descale_dP");
+
+  auto scale_s = graph.tensor_like(descale_q, "Scale_S");
+  auto scale_dP = graph.tensor_like(descale_q, "Scale_dP");
+  auto scale_dQ = graph.tensor_like(descale_q, "Scale_dQ");
+  auto scale_dK = graph.tensor_like(descale_q, "Scale_dK");
+  auto scale_dV = graph.tensor_like(descale_q, "Scale_dV");
+
+  descale_k->set_uid(next_uid());
+  descale_v->set_uid(next_uid());
+  descale_s->set_uid(next_uid());
+  descale_o->set_uid(next_uid());
+  descale_dO->set_uid(next_uid());
+  descale_dP->set_uid(next_uid());
+
+  scale_s->set_uid(next_uid());
+  scale_dP->set_uid(next_uid());
+  scale_dQ->set_uid(next_uid());
+  scale_dK->set_uid(next_uid());
+  scale_dV->set_uid(next_uid());
+
+  bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL;
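+  // sdpa_fp8_backward consumes Q, K, V, the forward output O, dO and the
+  // forward softmax Stats together with the descale/scale factors declared
+  // above, and produces FP8 dQ/dK/dV plus the amax values of dQ, dK, dV, dP.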
+  auto sdpa_fp8_backwards_options =
+      cudnn_frontend::graph::SDPA_fp8_backward_attributes()
+          .set_name("sdpa_fp8_backward")
+          .set_causal_mask(is_causal)
+          .set_attn_scale(scale);
+
+  auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] =
+      graph.sdpa_fp8_backward(q, k, v, o, dO, Stats, descale_q, descale_k,
+                              descale_v, descale_o, descale_dO, descale_s,
+                              descale_dP, scale_s, scale_dQ, scale_dK,
+                              scale_dV, scale_dP, sdpa_fp8_backwards_options);
+
+  dQ->set_output(true)
+      .set_dim(dq_desc.dimensions())
+      .set_stride(dq_desc.GetLogicalStrides())
+      .set_name("dQ")
+      .set_uid(next_uid())
+      .set_data_type(ioDataType);
+  dK->set_output(true)
+      .set_dim(dk_desc.dimensions())
+      .set_stride(dk_desc.GetLogicalStrides())
+      .set_name("dK")
+      .set_uid(next_uid())
+      .set_data_type(ioDataType);
+  dV->set_output(true)
+      .set_dim(dv_desc.dimensions())
+      .set_stride(dv_desc.GetLogicalStrides())
+      .set_name("dV")
+      .set_uid(next_uid())
+      .set_data_type(ioDataType);
+  Amax_dQ->set_output(true)
+      .set_dim({1, 1, 1, 1})
+      .set_stride({1, 1, 1, 1})
+      .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_uid(next_uid());
+  Amax_dK->set_output(true)
+      .set_dim({1, 1, 1, 1})
+      .set_stride({1, 1, 1, 1})
+      .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_uid(next_uid());
+  Amax_dV->set_output(true)
+      .set_dim({1, 1, 1, 1})
+      .set_stride({1, 1, 1, 1})
+      .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_uid(next_uid());
+  Amax_dP->set_output(true)
+      .set_dim({1, 1, 1, 1})
+      .set_stride({1, 1, 1, 1})
+      .set_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_uid(next_uid());
+
+  CudnnGraph cudnnGraph(std::move(graph));
+  TF_RETURN_IF_ERROR(cudnnGraph.Prepare(
+      dnn_support, NumericOptions{/*require_determinism=*/false,
+                                  /*allow_tf32=*/true}));
+  TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt));
+
+  if (VLOG_IS_ON(4)) {
+    VLOG(4) << "\n workspace size: " << cudnnGraph.Graph().get_workspace_size();
+    VLOG(4) << "\n flash attention fp8 backward operation graph: " << graph;
+  }
+
+  return cudnnGraph;
+#else
+  return absl::UnimplementedError(
+      "Cudnn flash attention only supported with Cudnn >= 9.0.0");
+#endif
+}
+
 absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
     dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc,
     const dnn::MatmulTensorDescriptor& k_desc,
diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h
index 78578a2aa9ab91..ee07c85d025d53 100644
--- a/xla/stream_executor/cuda/cuda_dnn.h
+++ b/xla/stream_executor/cuda/cuda_dnn.h
@@ -700,6 +700,19 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
     const std::optional<double> dropout_rate, const dnn::FMHAMaskKind mask_type);
 
+absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
+    dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc,
+    const dnn::MatmulTensorDescriptor& k_desc,
+    const dnn::MatmulTensorDescriptor& p_desc,
+    const dnn::MatmulTensorDescriptor& v_desc,
+    const dnn::MatmulTensorDescriptor& do_desc,
+    const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc,
+    const dnn::TensorDescriptor& dv_desc,
+    const std::optional<dnn::TensorDescriptor> bias_descriptor,
+    std::optional<double> dropout_rate, std::optional<int64_t> seed,
+    double scale, bool use_dropout, bool use_bias,
+    const dnn::FMHAMaskKind mask_type, bool force_deterministic);
+
 absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionF8OperationGraph(
     dnn::DnnSupport& dnn_support,
     const dnn::MatmulTensorDescriptor& q_descriptor,
@@ -709,18 +722,15 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionF8OperationGraph(
     const std::optional<dnn::TensorDescriptor> stats_descriptor, const float scale,
     const dnn::FMHAMaskKind mask_type);
 
-absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
+absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardF8OperationGraph(
     dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc,
     const dnn::MatmulTensorDescriptor& k_desc,
     const dnn::MatmulTensorDescriptor& p_desc,
     const dnn::MatmulTensorDescriptor& v_desc,
     const dnn::MatmulTensorDescriptor& do_desc,
     const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc,
-    const dnn::TensorDescriptor& dv_desc,
-    const std::optional<dnn::TensorDescriptor> bias_descriptor,
-    std::optional<double> dropout_rate, std::optional<int64_t> seed,
-    double scale, bool use_dropout, bool use_bias,
-    const dnn::FMHAMaskKind mask_type, bool force_deterministic);
+    const dnn::TensorDescriptor& dv_desc, double scale,
+    dnn::FMHAMaskKind mask_type);
 
 }  // namespace gpu
 }  // namespace stream_executor