PR #15331: Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward)

Imported from GitHub PR #15331

This is the second part of #15092.
NOTE: this feature relies on cudnn-frontend v1.6.1, which is not in XLA yet.
Copybara import of the project:

--
06db3c8 by shuw <shuw@nvidia.com>:

Scaled dot product attention implementation by cudnn.

--
937b0e2 by shuw <shuw@nvidia.com>:

Improve after review 1

--
398b2ba by shuw <shuw@nvidia.com>:

clang-format

--
0825789 by Shu Wang <shuw@nvidia.com>:

fix typo.

--
d0ae3cf by shuw <shuw@nvidia.com>:

Refactor test

Merging this change closes #15331

FUTURE_COPYBARA_INTEGRATE_REVIEW=#15331 from wenscarl:sdpa_fp8_bwd d0ae3cf
PiperOrigin-RevId: 684025541
wenscarl authored and Google-ML-Automation committed Oct 9, 2024
1 parent 23c0e07 commit fab1b65
Showing 6 changed files with 544 additions and 7 deletions.
14 changes: 8 additions & 6 deletions xla/service/gpu/cublas_cudnn.cc
@@ -128,11 +128,13 @@ bool IsCustomCallToDnnNorm(const HloInstruction& hlo) {
}

bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) {
-  if (hlo.opcode() != HloOpcode::kCustomCall) {
-    return false;
-  }
-  const auto& target = hlo.custom_call_target();
-  return target == kCudnnfMHASoftmaxF8CallTarget;
+  return hlo.opcode() == HloOpcode::kCustomCall &&
+         hlo.custom_call_target() == kCudnnfMHASoftmaxF8CallTarget;
}

+bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) {
+  return hlo.opcode() == HloOpcode::kCustomCall &&
+         hlo.custom_call_target() == kCudnnfMHASoftmaxBackwardF8CallTarget;
+}

bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) {
@@ -169,7 +171,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) {
}

bool IsCustomCallTofMHAF8(const HloInstruction& hlo) {
-  return IsFwdCustomCallTofMHAF8(hlo);
+  return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo);
}

bool IsCubDeviceRadixSort(const HloInstruction& hlo) {
2 changes: 2 additions & 0 deletions xla/service/gpu/cublas_cudnn.h
@@ -188,13 +188,15 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget;
// Backward calls
extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget;
extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget;
extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget;
extern const absl::string_view
kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;

bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsFwdCustomCallTofMHA(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHA(const HloInstruction& hlo);
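Taken together with the changes in cublas_cudnn.cc above, the header now exposes forward and backward FP8 fused-attention predicates side by side. The following is a minimal, hypothetical usage sketch; only the predicate names and the backward call target "__cudnn$fmhaSoftmaxBackwardF8" come from this change, while the function and its include set are illustrative.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/cublas_cudnn.h"

// Illustrative dispatch over the FP8 fused multi-head attention custom calls;
// the predicates are the ones declared in cublas_cudnn.h by this change.
void HandleFp8FusedAttention(const xla::HloInstruction& hlo) {
  if (xla::gpu::IsFwdCustomCallTofMHAF8(hlo)) {
    // Forward FP8 fused attention custom call.
  } else if (xla::gpu::IsBwdCustomCallTofMHAF8(hlo)) {
    // Backward FP8 fused attention custom call
    // (custom_call_target "__cudnn$fmhaSoftmaxBackwardF8").
  }
}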
258 changes: 258 additions & 0 deletions xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -1813,6 +1813,264 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxDropoutBMM,
Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2) {
TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2();
}

XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8,
Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
se::dnn::VersionInfo(9, 1, 0)) {
GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0.";
}
XlaBuilder builder(TestName());
std::string hlo_string_ref = R"(
HloModule fmha_cudnn_custom_call_bwd
// Process inputs: clip, convert to f8e4m3fn, and convert back to bf16
cast_to_representable {
// Parameters
input = bf16[1,1,128,128] parameter(0)
min_val = bf16[] parameter(1)
max_val = bf16[] parameter(2)
// Broadcasting min and max values
min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={}
max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={}
// Clipping the scaled input
clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input)
clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min)
// Converting to f8e4m3fn and back to bf16
converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped)
ROOT converted_bf16 = bf16[1,1,128,128] convert(converted_f8)
}
// Main function
ENTRY main {
// Input parameters
query = bf16[1,1,128,128] parameter(0)
key = bf16[1,1,128,128] parameter(1)
value = bf16[1,1,128,128] parameter(2)
grad_output = bf16[1,1,128,128] parameter(3)
fwd_output = bf16[1,1,128,128] parameter(4)
score = f32[1,1,128] parameter(5)
// Constants
one_f32 = f32[] constant(1)
one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={}
min_clip_val = bf16[] constant(-448)
max_clip_val = bf16[] constant(448)
query_processed = bf16[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable
key_processed = bf16[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable
value_processed = bf16[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable
grad_output_processed = bf16[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
fwd_output_processed = bf16[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
// FMHA Softmax Backward custom call (bf16 reference)
fmha_result = (bf16[1,1,128,128], bf16[1,1,128,128], bf16[1,1,128,128], u8[0]) custom-call(
query_processed, key_processed, value_processed,
score, fwd_output_processed, grad_output_processed
),
custom_call_target="__cudnn$fmhaSoftmaxBackward",
operand_layout_constraints={
bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0},
bf16[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0},
bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0}
},
api_version=API_VERSION_STATUS_RETURNING,
backend_config={
"operation_queue_id": "0",
"wait_on_operation_queues": [],
"cudnn_fmha_backend_config": {
"algorithm": {
"algo_id": "0",
"math_type": "TENSOR_OP_MATH",
"tuning_knobs": {"17": "1", "24": "0"},
"is_cudnn_frontend": true,
"workspace_size": "0"
},
"fmha_scale": 1.0,
"dropout_rate": 0.0,
"intermediate_tensor_shape": {
"element_type": "BF16",
"dimensions": ["1", "1", "128", "128"],
"tuple_shapes": [],
"layout": {
"dim_level_types": [],
"dim_unique": [],
"dim_ordered": [],
"minor_to_major": ["3", "2", "1", "0"],
"tiles": [],
"element_size_in_bits": "0",
"memory_space": "0",
"index_primitive_type": "PRIMITIVE_TYPE_INVALID",
"pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
"dynamic_shape_metadata_prefix_bytes": "0"
},
"is_dynamic_dimension": [false, false, false, false]
},
"seed": 42,
"is_flash_attention": true,
"mask_type": "NO_MASK",
"sliding_window_length": 0,
"bmm1_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm1_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
}
}
}
ROOT output = bf16[1,1,128,128] get-tuple-element(fmha_result), index=0
})";

std::string hlo_string = R"(
HloModule fmha_cudnn_custom_call_bwd_f8
// Process inputs: clip, convert to f8e4m3fn
cast_to_representable {
// Parameters
input = bf16[1,1,128,128] parameter(0)
min_val = bf16[] parameter(1)
max_val = bf16[] parameter(2)
// Broadcasting min and max values
min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={}
max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={}
// Clipping the scaled input
clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input)
clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min)
// Converting to f8e4m3fn
ROOT converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped)
}
// Main function
ENTRY main {
// Input parameters
query = bf16[1,1,128,128] parameter(0)
key = bf16[1,1,128,128] parameter(1)
value = bf16[1,1,128,128] parameter(2)
grad_output = bf16[1,1,128,128] parameter(3)
fwd_output = bf16[1,1,128,128] parameter(4)
score = f32[1,1,128] parameter(5)
// Constants
one_f32 = f32[] constant(1)
one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={}
min_clip_val = bf16[] constant(-448)
max_clip_val = bf16[] constant(448)
query_processed = f8e4m3fn[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable
key_processed = f8e4m3fn[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable
value_processed = f8e4m3fn[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable
grad_output_processed = f8e4m3fn[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
fwd_output_processed = f8e4m3fn[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
// FMHA Softmax Backward custom call
fmha_result = (f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128],
f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], u8[0]) custom-call(
query_processed, key_processed, value_processed,
grad_output_processed, fwd_output_processed, score,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast
),
custom_call_target="__cudnn$fmhaSoftmaxBackwardF8",
operand_layout_constraints={
f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0},
f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0},
f8e4m3fn[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}
},
api_version=API_VERSION_STATUS_RETURNING,
backend_config={
"operation_queue_id": "0",
"wait_on_operation_queues": [],
"cudnn_fmha_backend_config": {
"algorithm": {
"algo_id": "0",
"math_type": "TENSOR_OP_MATH",
"tuning_knobs": {"17": "1", "24": "0"},
"is_cudnn_frontend": true,
"workspace_size": "0"
},
"fmha_scale": 1.0,
"intermediate_tensor_shape": {
"element_type": "BF16",
"dimensions": ["1", "1", "128", "128"],
"tuple_shapes": [],
"layout": {
"dim_level_types": [],
"dim_unique": [],
"dim_ordered": [],
"minor_to_major": ["3", "2", "1", "0"],
"tiles": [],
"element_size_in_bits": "0",
"memory_space": "0",
"index_primitive_type": "PRIMITIVE_TYPE_INVALID",
"pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
"dynamic_shape_metadata_prefix_bytes": "0"
},
"is_dynamic_dimension": [false, false, false, false]
},
"is_flash_attention": true,
"mask_type": "NO_MASK",
"bmm1_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm1_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
}
}
}
fmha_output = f8e4m3fn[1,1,128,128] get-tuple-element(fmha_result), index=0
ROOT output = bf16[1,1,128,128] convert(fmha_output)
})";

EXPECT_TRUE(RunAndCompareTwoModules(hlo_string_ref, hlo_string,
ErrorSpec{2e-1, 2e-1}));
}
} // namespace
} // namespace gpu
} // namespace xla
71 changes: 70 additions & 1 deletion xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -199,7 +199,7 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, activation,
static_cast<float>(config.fmha_scale()), dnn_mask_type));
return std::move(graph);
-  } else {
+  } else if (IsBwdCustomCallTofMHA(*custom_call)) {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
@@ -314,6 +314,75 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt,
dnn_mask_type, force_deterministic, sliding_window_length));
return std::move(graph);
} else {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
xla::gpu::CudnnfMHABackendConfig &config =
*gpu_config.mutable_cudnn_fmha_backend_config();

Shape bmm1_grad_gemm1_rhs_shape = custom_call->operand(0)->shape();
Shape bmm1_grad_gemm2_rhs_shape = custom_call->operand(1)->shape();
Shape bmm2_grad_gemm2_rhs_shape = custom_call->operand(2)->shape();

Shape fwd_output_shape = custom_call->operand(3)->shape();
Shape d_output_shape = custom_call->operand(4)->shape();

Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());

Shape d_bmm1_lhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {0});
Shape d_bmm1_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {1});
Shape d_bmm2_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {2});

TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm1_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm1_rhs_shape,
config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm2_rhs_shape,
config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm1_lhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm1_lhs_shape,
config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm2_rhs_shape,
config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor d_output,
MatmulTensorDescriptorFor(
d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(),
RHS));

TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs,
TensorDescriptorFor(d_bmm1_lhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs,
TensorDescriptorFor(d_bmm1_rhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs,
TensorDescriptorFor(d_bmm2_rhs_shape));
// 3 gradients, 4 amaxs and one workspace
TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size());

TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));

TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
AsCudnnFmhaMaskKind(config.mask_type()));
TF_ASSIGN_OR_RETURN(
se::dnn::FMHAMaskKind dnn_mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type));
TF_ASSIGN_OR_RETURN(
se::gpu::CudnnGraph graph,
se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(
dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs,
bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs,
d_bmm1_rhs, d_bmm2_rhs, config.fmha_scale(), dnn_mask_type));
return std::move(graph);
}
}

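For orientation, the hunks above slot into the overall dispatch of HloCustomCallToCuDnnGraph. The snippet below is a condensed, hypothetical restatement that assumes the types and helpers already in scope in cudnn_custom_call_compiler.cc; only IsBwdCustomCallTofMHA and se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph are taken from the diff, and the elided branches are summarized in comments.

// Condensed sketch only -- not the literal function body. The forward-pass
// branches are elided because they fall outside the hunks shown above.
if (/* forward fMHA custom call (FP8 or non-FP8), handled earlier */) {
  // Builds and returns the forward cuDNN graph.
} else if (IsBwdCustomCallTofMHA(*custom_call)) {
  // Existing non-FP8 backward path: builds MatmulTensorDescriptor /
  // TensorDescriptor operands from the bmm*_grad_gemm* dot dimension numbers
  // and returns the backward cuDNN graph.
} else {
  // New in this change: FP8 backward path. Builds the descriptors shown
  // above, checks the 8-element result tuple (3 gradients, 4 amaxs, one
  // workspace), and returns the graph built by
  // se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(...).
}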