
PR #15331: Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward) #18100

Merged: 1 commit, Oct 9, 2024
14 changes: 8 additions & 6 deletions xla/service/gpu/cublas_cudnn.cc
@@ -128,11 +128,13 @@ bool IsCustomCallToDnnNorm(const HloInstruction& hlo) {
}

bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) {
-  if (hlo.opcode() != HloOpcode::kCustomCall) {
-    return false;
-  }
-  const auto& target = hlo.custom_call_target();
-  return target == kCudnnfMHASoftmaxF8CallTarget;
+  return hlo.opcode() == HloOpcode::kCustomCall &&
+         hlo.custom_call_target() == kCudnnfMHASoftmaxF8CallTarget;
}

+bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) {
+  return hlo.opcode() == HloOpcode::kCustomCall &&
+         hlo.custom_call_target() == kCudnnfMHASoftmaxBackwardF8CallTarget;
+}

bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) {
@@ -169,7 +171,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) {
}

bool IsCustomCallTofMHAF8(const HloInstruction& hlo) {
-  return IsFwdCustomCallTofMHAF8(hlo);
+  return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo);
}

bool IsCubDeviceRadixSort(const HloInstruction& hlo) {
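For orientation only (this helper is not part of the PR), a downstream pass can use the new predicates to tell the forward and backward FP8 fMHA custom calls apart; the sketch below assumes the usual XLA include paths and the declarations from cublas_cudnn.h.

// Hypothetical helper, not in this diff: classifies an instruction using the
// predicates added/updated above. Include paths are assumptions.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/cublas_cudnn.h"

namespace xla::gpu {

enum class FmhaF8Kind { kNone, kForward, kBackward };

FmhaF8Kind ClassifyFmhaF8(const HloInstruction& hlo) {
  if (IsFwdCustomCallTofMHAF8(hlo)) return FmhaF8Kind::kForward;
  if (IsBwdCustomCallTofMHAF8(hlo)) return FmhaF8Kind::kBackward;
  return FmhaF8Kind::kNone;
}

}  // namespace xla::gpu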
2 changes: 2 additions & 0 deletions xla/service/gpu/cublas_cudnn.h
@@ -188,13 +188,15 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget;
// Backward calls
extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget;
extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget;
extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget;
extern const absl::string_view
kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;

bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsFwdCustomCallTofMHA(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHA(const HloInstruction& hlo);
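The definition of the new backward call target is not included in the hunks shown here; judging from the custom_call_target string used by the test below, it presumably mirrors the forward target along these lines (a sketch of the assumed pairing, not the verbatim source):

// Assumed definition, inferred from the test's custom_call_target below;
// the real one lives in cublas_cudnn.cc outside the displayed hunks.
#include "absl/strings/string_view.h"

namespace xla::gpu {
const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget =
    "__cudnn$fmhaSoftmaxBackwardF8";
}  // namespace xla::gpu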
258 changes: 258 additions & 0 deletions xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -1813,6 +1813,264 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxDropoutBMM,
Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2) {
TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2();
}

XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8,
Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
se::dnn::VersionInfo(9, 1, 0)) {
GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0.";
}
XlaBuilder builder(TestName());
std::string hlo_string_ref = R"(
HloModule fmha_cudnn_custom_call_bwd
// Process inputs: clip, convert to f8e4m3fn, and convert back to bf16
cast_to_representable {
// Parameters
input = bf16[1,1,128,128] parameter(0)
min_val = bf16[] parameter(1)
max_val = bf16[] parameter(2)

// Broadcasting min and max values
min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={}
max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={}

// Clipping the scaled input
clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input)
clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min)

// Converting to f8e4m3fn and back to bf16
converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped)
ROOT converted_bf16 = bf16[1,1,128,128] convert(converted_f8)
}
// Main function
ENTRY main {
// Input parameters
query = bf16[1,1,128,128] parameter(0)
key = bf16[1,1,128,128] parameter(1)
value = bf16[1,1,128,128] parameter(2)
grad_output = bf16[1,1,128,128] parameter(3)
fwd_output = bf16[1,1,128,128] parameter(4)
score = f32[1,1,128] parameter(5)

// Constants
one_f32 = f32[] constant(1)
one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={}
min_clip_val = bf16[] constant(-448)
max_clip_val = bf16[] constant(448)

query_processed = bf16[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable
key_processed = bf16[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable
value_processed = bf16[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable
grad_output_processed = bf16[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
fwd_output_processed = bf16[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable

// FMHA Forward Backward custom call
fmha_result = (bf16[1,1,128,128], bf16[1,1,128,128], bf16[1,1,128,128], u8[0]) custom-call(
query_processed, key_processed, value_processed,
score, fwd_output_processed, grad_output_processed
),
custom_call_target="__cudnn$fmhaSoftmaxBackward",
operand_layout_constraints={
bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0},
bf16[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0},
bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0}
},
api_version=API_VERSION_STATUS_RETURNING,
backend_config={
"operation_queue_id": "0",
"wait_on_operation_queues": [],
"cudnn_fmha_backend_config": {
"algorithm": {
"algo_id": "0",
"math_type": "TENSOR_OP_MATH",
"tuning_knobs": {"17": "1", "24": "0"},
"is_cudnn_frontend": true,
"workspace_size": "0"
},
"fmha_scale": 1.0,
"dropout_rate": 0.0,
"intermediate_tensor_shape": {
"element_type": "BF16",
"dimensions": ["1", "1", "128", "128"],
"tuple_shapes": [],
"layout": {
"dim_level_types": [],
"dim_unique": [],
"dim_ordered": [],
"minor_to_major": ["3", "2", "1", "0"],
"tiles": [],
"element_size_in_bits": "0",
"memory_space": "0",
"index_primitive_type": "PRIMITIVE_TYPE_INVALID",
"pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
"dynamic_shape_metadata_prefix_bytes": "0"
},
"is_dynamic_dimension": [false, false, false, false]
},
"seed": 42,
"is_flash_attention": true,
"mask_type": "NO_MASK",
"sliding_window_length": 0,
"bmm1_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm1_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
}
}
}

ROOT output = bf16[1,1,128,128] get-tuple-element(fmha_result), index=0
})";
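  // Editorial note, not part of the PR: the reference module above runs the
  // bf16 "__cudnn$fmhaSoftmaxBackward" custom call on inputs that were
  // round-tripped through f8e4m3fn, while the module below feeds genuine
  // f8e4m3fn operands to the new "__cudnn$fmhaSoftmaxBackwardF8" call and
  // additionally passes twelve f32[1,1,1,1] scale/amax operands (all set to
  // constant 1 here).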

std::string hlo_string = R"(
HloModule fmha_cudnn_custom_call_bwd_f8
// Process inputs: clip, convert to f8e4m3fn
cast_to_representable {
// Parameters
input = bf16[1,1,128,128] parameter(0)
min_val = bf16[] parameter(1)
max_val = bf16[] parameter(2)

// Broadcasting min and max values
min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={}
max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={}

// Clipping the scaled input
clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input)
clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min)

// Converting to f8e4m3fn and back to bf16
ROOT converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped)
}

// Main function
ENTRY main {
// Input parameters
query = bf16[1,1,128,128] parameter(0)
key = bf16[1,1,128,128] parameter(1)
value = bf16[1,1,128,128] parameter(2)
grad_output = bf16[1,1,128,128] parameter(3)
fwd_output = bf16[1,1,128,128] parameter(4)
score = f32[1,1,128] parameter(5)

// Constants
one_f32 = f32[] constant(1)
one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={}
min_clip_val = bf16[] constant(-448)
max_clip_val = bf16[] constant(448)

query_processed = f8e4m3fn[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable
key_processed = f8e4m3fn[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable
value_processed = f8e4m3fn[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable
grad_output_processed = f8e4m3fn[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable
fwd_output_processed = f8e4m3fn[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable

// FMHA Softmax Backward custom call
fmha_result = (f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128],
f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], u8[0]) custom-call(
query_processed, key_processed, value_processed,
grad_output_processed, fwd_output_processed, score,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast,
one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast
),
custom_call_target="__cudnn$fmhaSoftmaxBackwardF8",
operand_layout_constraints={
f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0},
f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0},
f8e4m3fn[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0},
f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}
},
api_version=API_VERSION_STATUS_RETURNING,
backend_config={
"operation_queue_id": "0",
"wait_on_operation_queues": [],
"cudnn_fmha_backend_config": {
"algorithm": {
"algo_id": "0",
"math_type": "TENSOR_OP_MATH",
"tuning_knobs": {"17": "1", "24": "0"},
"is_cudnn_frontend": true,
"workspace_size": "0"
},
"fmha_scale": 1.0,
"intermediate_tensor_shape": {
"element_type": "BF16",
"dimensions": ["1", "1", "128", "128"],
"tuple_shapes": [],
"layout": {
"dim_level_types": [],
"dim_unique": [],
"dim_ordered": [],
"minor_to_major": ["3", "2", "1", "0"],
"tiles": [],
"element_size_in_bits": "0",
"memory_space": "0",
"index_primitive_type": "PRIMITIVE_TYPE_INVALID",
"pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
"dynamic_shape_metadata_prefix_bytes": "0"
},
"is_dynamic_dimension": [false, false, false, false]
},
"is_flash_attention": true,
"mask_type": "NO_MASK",
"bmm1_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm1_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm1_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["2"],
"rhs_contracting_dimensions": ["2"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
},
"bmm2_grad_gemm2_dot_dimension_numbers": {
"lhs_contracting_dimensions": ["3"],
"rhs_contracting_dimensions": ["3"],
"lhs_batch_dimensions": ["0", "1"],
"rhs_batch_dimensions": ["0", "1"]
}
}
}

fmha_output = f8e4m3fn[1,1,128,128] get-tuple-element(fmha_result), index=0
ROOT output = bf16[1,1,128,128] convert(fmha_output)
})";

EXPECT_TRUE(RunAndCompareTwoModules(hlo_string_ref, hlo_string,
ErrorSpec{2e-1, 2e-1}));
}
} // namespace
} // namespace gpu
} // namespace xla
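In both modules the cast_to_representable helper clamps its bf16 input to [-448, 448], the finite range of f8e4m3fn, before the narrowing conversion, so the bf16 reference and the FP8 path see numerically comparable data; the two modules are then compared with a loose ErrorSpec of 2e-1. A minimal C++ sketch of that clamping step, with illustrative names not taken from the PR:

// Sketch of the clamp performed by cast_to_representable in the HLO above.
// 448 is the largest finite value representable in f8e4m3fn, so anything
// outside [-448, 448] would not survive the f8 conversion.
#include <algorithm>

constexpr float kF8E4M3MaxFinite = 448.0f;

float ClampToF8E4M3Range(float x) {
  return std::clamp(x, -kF8E4M3MaxFinite, kF8E4M3MaxFinite);
}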
71 changes: 70 additions & 1 deletion xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -199,7 +199,7 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, activation,
static_cast<float>(config.fmha_scale()), dnn_mask_type));
return std::move(graph);
-  } else {
+  } else if (IsBwdCustomCallTofMHA(*custom_call)) {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
@@ -314,6 +314,75 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt,
dnn_mask_type, force_deterministic, sliding_window_length));
return std::move(graph);
} else {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
xla::gpu::CudnnfMHABackendConfig &config =
*gpu_config.mutable_cudnn_fmha_backend_config();

Shape bmm1_grad_gemm1_rhs_shape = custom_call->operand(0)->shape();
Shape bmm1_grad_gemm2_rhs_shape = custom_call->operand(1)->shape();
Shape bmm2_grad_gemm2_rhs_shape = custom_call->operand(2)->shape();

Shape fwd_output_shape = custom_call->operand(3)->shape();
Shape d_output_shape = custom_call->operand(4)->shape();

Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());

Shape d_bmm1_lhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {0});
Shape d_bmm1_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {1});
Shape d_bmm2_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {2});

TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm1_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm1_rhs_shape,
config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm2_rhs_shape,
config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm1_lhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm1_lhs_shape,
config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm2_rhs_shape,
config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor d_output,
MatmulTensorDescriptorFor(
d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(),
RHS));

TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs,
TensorDescriptorFor(d_bmm1_lhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs,
TensorDescriptorFor(d_bmm1_rhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs,
TensorDescriptorFor(d_bmm2_rhs_shape));
// 3 gradients, 4 amaxs and one workspace
TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size());

TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));

TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
AsCudnnFmhaMaskKind(config.mask_type()));
TF_ASSIGN_OR_RETURN(
se::dnn::FMHAMaskKind dnn_mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type));
TF_ASSIGN_OR_RETURN(
se::gpu::CudnnGraph graph,
se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(
dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs,
bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs,
d_bmm1_rhs, d_bmm2_rhs, config.fmha_scale(), dnn_mask_type));
return std::move(graph);
}
}
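The new else branch handles the backward FP8 case: it builds matmul tensor descriptors for Q, K, V, the intermediate tensor, and the incoming gradient, checks that the custom call returns an 8-element tuple (three gradients, four amax scalars, one workspace), and then requests the backward F8 cuDNN graph via se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph. A tiny sketch of that arity check, with illustrative names not taken from the PR:

// Illustrative sketch mirroring the TF_RET_CHECK(8 == ...) above: the
// backward F8 call is expected to produce dQ, dK, dV, four amax scalars,
// and one u8[] workspace, i.e. an 8-element result tuple.
#include <cstddef>

constexpr std::size_t kNumGradients = 3;    // dQ, dK, dV
constexpr std::size_t kNumAmaxOutputs = 4;  // amax values from the F8 kernel
constexpr std::size_t kNumWorkspaces = 1;   // scratch buffer

bool HasExpectedBwdF8TupleArity(std::size_t tuple_size) {
  return tuple_size == kNumGradients + kNumAmaxOutputs + kNumWorkspaces;
}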
