bwd
wenscarl committed Aug 16, 2024
1 parent 5013343 commit 949191a
Showing 6 changed files with 465 additions and 9 deletions.
10 changes: 9 additions & 1 deletion xla/service/gpu/cublas_cudnn.cc
@@ -135,6 +135,14 @@ bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) {
return target == kCudnnfMHASoftmaxF8CallTarget;
}

bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
return false;
}
const auto& target = hlo.custom_call_target();
return target == kCudnnfMHASoftmaxBackwardF8CallTarget;
}

bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
return false;
@@ -169,7 +177,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) {
}

bool IsCustomCallTofMHAF8(const HloInstruction& hlo) {
return IsFwdCustomCallTofMHAF8(hlo);
return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo);
}

bool IsCubDeviceRadixSort(const HloInstruction& hlo) {
2 changes: 2 additions & 0 deletions xla/service/gpu/cublas_cudnn.h
@@ -188,13 +188,15 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget;
// Backward calls
extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget;
extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget;
extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget;
extern const absl::string_view
kCudnnfMHAScaleBiasSoftmaxDropoutBackwardCallTarget;
extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget;

bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsCustomCallTofMHAF8(const HloInstruction& hlo);
bool IsFwdCustomCallTofMHA(const HloInstruction& hlo);
bool IsBwdCustomCallTofMHA(const HloInstruction& hlo);
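Editorial note: the new IsBwdCustomCallTofMHAF8 predicate mirrors the existing forward check, and IsCustomCallTofMHAF8 now covers both directions. Below is a minimal sketch of how a caller might dispatch on these predicates; the helper name is a hypothetical illustration and the include is assumed from the repository layout, neither is part of this commit.

#include <string>

#include "xla/service/gpu/cublas_cudnn.h"

namespace xla::gpu {

// Illustrative helper (not in this commit): classifies an instruction using
// the predicates declared above.
std::string DescribeFmhaF8CustomCall(const HloInstruction& hlo) {
  if (!IsCustomCallTofMHAF8(hlo)) {
    return "not an FP8 fused-MHA custom call";
  }
  // IsCustomCallTofMHAF8 is now the union of the forward and backward checks.
  return IsBwdCustomCallTofMHAF8(hlo) ? "FP8 fused-MHA backward"
                                      : "FP8 fused-MHA forward";
}

}  // namespace xla::gpu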
122 changes: 122 additions & 0 deletions xla/service/gpu/tests/gpu_fused_mha_test.cc
@@ -1189,6 +1189,123 @@ class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8
EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref,
ErrorSpec{1e-2, 1e-2}));
}

void TestImpl_Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8() {
if (skip_reason_) GTEST_SKIP() << *skip_reason_;
if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) <
se::dnn::VersionInfo(9, 0, 0)) {
GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0.";
}
XlaBuilder builder(TestName());
// The padding mask is generated inside cuDNN directly; the XLA pattern
// matcher does not support matching the padding-mask pattern, so the
// reference module lowers straight to the cuDNN custom call.
std::string hlo_string_ref = R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}
clip.33 {
Arg_2.36 = bf16[] parameter(2)
broadcast.39 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_2.36), dimensions={}
Arg_1.35 = bf16[] parameter(1)
broadcast.37 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_1.35), dimensions={}
Arg_0.34 = bf16[1,1,256,128]{3,2,1,0} parameter(0)
maximum.38 = bf16[1,1,256,128]{3,2,1,0} maximum(broadcast.37, Arg_0.34)
ROOT minimum.40 = bf16[1,1,256,128]{3,2,1,0} minimum(broadcast.39, maximum.38)
} // clip.33
ENTRY main.106 {
constant.99 = f32[] constant(1)
broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={}
Arg_0.1 = bf16[1,1,256,128]{3,2,1,0} parameter(0)
constant.6 = bf16[] constant(1)
broadcast.7 = bf16[1,1,256,128]{3,2,1,0} broadcast(constant.6), dimensions={}
divide.8 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_0.1, broadcast.7)
constant.5 = bf16[] constant(-448)
constant.4 = bf16[] constant(448)
call.17 = bf16[1,1,256,128]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33
convert.18 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.17)
convert.19 = bf16[1,1,256,128]{3,2,1,0} convert(convert.18)
Arg_1.2 = bf16[1,1,256,128]{3,2,1,0} parameter(1)
divide.20 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_1.2, broadcast.7)
call.29 = bf16[1,1,256,128]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33
convert.30 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.29)
convert.31 = bf16[1,1,256,128]{3,2,1,0} convert(convert.30)
Arg_2.3 = bf16[1,1,256,128]{3,2,1,0} parameter(2)
divide.32 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_2.3, broadcast.7)
call.41 = bf16[1,1,256,128]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33
convert.42 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.41)
convert.43 = bf16[1,1,256,128]{3,2,1,0} convert(convert.42)
Arg_3.4 = bf16[1,1,256,128]{3,2,1,0} parameter(3)
divide.72 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_3.4, broadcast.7)
call.71 = bf16[1,1,256,128]{3,2,1,0} call(divide.72, constant.5, constant.4), to_apply=clip.33
convert.72 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.71)
convert.73 = bf16[1,1,256,128]{3,2,1,0} convert(convert.72)
Arg_4.5 = bf16[1,1,256,128]{3,2,1,0} parameter(4)
divide.82 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_4.5, broadcast.7)
call.81 = bf16[1,1,256,128]{3,2,1,0} call(divide.82, constant.5, constant.4), to_apply=clip.33
convert.82 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.81)
convert.83 = bf16[1,1,256,128]{3,2,1,0} convert(convert.82)
bt.0 = f32[1,1,256]{2,1,0} broadcast(constant.99), dimensions={}
custom-call.7 = (bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}, u8[0]{0}) custom-call(convert.19, convert.31, convert.43, bt.0, convert.83, /*index=5*/convert.73), custom_call_target="__cudnn$fmhaSoftmaxBackward", operand_layout_constraints={bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, f32[1,1,256]{2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "16", "16"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["2"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "1"]}}}
get-tuple-element.8 = bf16[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.7), index=0
get-tuple-element.9 = bf16[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.7), index=1
ROOT tuple.3 = (bf16[1,1,256,128]{3,1,2,0}, bf16[1,1,256,128]{3,1,2,0}) tuple(get-tuple-element.8, get-tuple-element.9)
})";

std::string hlo_string = R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}
clip.33 {
Arg_2.36 = bf16[] parameter(2)
broadcast.39 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_2.36), dimensions={}
Arg_1.35 = bf16[] parameter(1)
broadcast.37 = bf16[1,1,256,128]{3,2,1,0} broadcast(Arg_1.35), dimensions={}
Arg_0.34 = bf16[1,1,256,128]{3,2,1,0} parameter(0)
maximum.38 = bf16[1,1,256,128]{3,2,1,0} maximum(broadcast.37, Arg_0.34)
ROOT minimum.40 = bf16[1,1,256,128]{3,2,1,0} minimum(broadcast.39, maximum.38)
}
ENTRY main.106 {
constant.99 = f32[] constant(1)
broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={}
Arg_0.1 = bf16[1,1,256,128]{3,2,1,0} parameter(0)
constant.6 = bf16[] constant(1)
broadcast.7 = bf16[1,1,256,128]{3,2,1,0} broadcast(constant.6), dimensions={}
divide.8 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_0.1, broadcast.7)
constant.5 = bf16[] constant(-448)
constant.4 = bf16[] constant(448)
call.17 = bf16[1,1,256,128]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33
convert.18 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.17)
convert.19 = bf16[1,1,256,128]{3,2,1,0} convert(convert.18)
Arg_1.2 = bf16[1,1,256,128]{3,2,1,0} parameter(1)
divide.20 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_1.2, broadcast.7)
call.29 = bf16[1,1,256,128]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33
convert.30 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.29)
convert.31 = bf16[1,1,256,128]{3,2,1,0} convert(convert.30)
Arg_2.3 = bf16[1,1,256,128]{3,2,1,0} parameter(2)
divide.32 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_2.3, broadcast.7)
call.41 = bf16[1,1,256,128]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33
convert.42 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.41)
convert.43 = bf16[1,1,256,128]{3,2,1,0} convert(convert.42)
Arg_3.4 = bf16[1,1,256,128]{3,2,1,0} parameter(3)
divide.72 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_3.4, broadcast.7)
call.71 = bf16[1,1,256,128]{3,2,1,0} call(divide.72, constant.5, constant.4), to_apply=clip.33
convert.72 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.71)
convert.73 = bf16[1,1,256,128]{3,2,1,0} convert(convert.72)
Arg_4.5 = bf16[1,1,256,128]{3,2,1,0} parameter(4)
divide.82 = bf16[1,1,256,128]{3,2,1,0} divide(Arg_4.5, broadcast.7)
call.81 = bf16[1,1,256,128]{3,2,1,0} call(divide.82, constant.5, constant.4), to_apply=clip.33
convert.82 = f8e4m3fn[1,1,256,128]{3,2,1,0} convert(call.81)
convert.83 = bf16[1,1,256,128]{3,2,1,0} convert(convert.82)
bt.0 = f32[1,1,256]{2,1,0} broadcast(constant.99), dimensions={}
custom-call.9 = (f8e4m3fn[1,1,256,128]{3,1,2,0}, f8e4m3fn[1,1,256,128]{3,1,2,0}, f8e4m3fn[1,1,256,128]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, u8[0]{0}) custom-call(convert.18, convert.30, convert.42, convert.72, convert.82, bt.0, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99, broadcast.99), custom_call_target="__cudnn$fmhaSoftmaxBackward$f8"
get-tuple-element.10 = f8e4m3fn[1,1,256,128]{3,1,2,0} get-tuple-element(custom-call.9), index=0
ROOT out.10 = bf16[1,1,256,128]{3,1,2,0} convert(get-tuple-element.10)
})";

EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref,
ErrorSpec{1e-1, 1e-1}));
}
};

// BMM1 - Scale - CausalMask - Softmax - BMM2
@@ -1245,6 +1362,11 @@ XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8,
Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) {
TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8();
}

XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8,
Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) {
TestImpl_Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8();
}
} // namespace
} // namespace gpu
} // namespace xla
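Editorial note: both modules above quantize their bf16 inputs by clamping to the range [-448, 448] (the largest finite f8e4m3fn magnitude) via clip.33 before converting to f8e4m3fn. The standalone sketch below mirrors only that clamp; the function name is illustrative, and a real quantizer would additionally round to the nearest representable f8e4m3fn value after clamping.

#include <algorithm>

// Mirrors the clip.33 computation in the HLO above: clamp to the finite
// range of f8e4m3fn before the convert. Name is illustrative only.
float ClampToF8E4M3Range(float x) {
  constexpr float kF8E4M3Max = 448.0f;  // matches constant(448)/constant(-448)
  return std::min(std::max(x, -kF8E4M3Max), kF8E4M3Max);
}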
120 changes: 118 additions & 2 deletions xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -159,7 +159,7 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
static_cast<float>(config.fmha_scale()), dropout_rate > 0.0,
dropout_rate, dnn_mask_type));
return std::move(graph);
} else if (IsFwdCustomCallTofMHAF8(*custom_call)) {
} else if (IsFwdCustomCallTofMHAF8(*custom_call)) {
TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind,
xla::gpu::GetCudnnfMHAKind(custom_call));
TF_ASSIGN_OR_RETURN(
@@ -205,7 +205,7 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
rhs_bmm2, output, activation,
static_cast<float>(config.fmha_scale()), dnn_mask_type));
return std::move(graph);
} else {
} else if (IsBwdCustomCallTofMHA(*custom_call)) {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
@@ -322,6 +322,122 @@ absl::StatusOr<se::gpu::CudnnGraph> HloCustomCallToCuDnnGraph(
config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt,
dnn_mask_type, force_deterministic));
return std::move(graph);
} else {
TF_ASSIGN_OR_RETURN(
auto gpu_config,
custom_call->backend_config<xla::gpu::GpuBackendConfig>());
xla::gpu::CudnnfMHABackendConfig& config =
*gpu_config.mutable_cudnn_fmha_backend_config();

int input_index = 0;
Shape bmm1_grad_gemm1_rhs_shape =
custom_call->operand(input_index++)->shape();
Shape bmm1_grad_gemm2_rhs_shape =
custom_call->operand(input_index++)->shape();
Shape bmm2_grad_gemm2_rhs_shape =
custom_call->operand(input_index++)->shape();

Shape fwd_output_shape =
custom_call->operand(input_index++)->shape();
Shape d_output_shape = custom_call->operand(input_index++)->shape();

Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape());
input_index++;
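// Operand 5 is skipped here: the shape of bmm2_grad_gemm1_lhs (the
// intermediate tensor) is taken from the backend config above instead.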

TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind,
GetCudnnfMHAKind(custom_call));
std::cout << "xxxxxxxxxxxxxxx operand_count"
<< custom_call->operand_count();
std::cout << custom_call->ToString();
TF_RET_CHECK(input_index == 6);

int output_index = 0;
Shape d_bmm1_lhs_shape =
ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
Shape d_bmm1_rhs_shape =
ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
Shape d_bmm2_rhs_shape =
ShapeUtil::GetSubshape(custom_call->shape(), {output_index++});
std::cout << "\nyyyyy operand_count"
<< custom_call->shape().tuple_shapes().size()
<< "; and output_index=" << output_index;

TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm1_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm1_rhs_shape,
config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm1_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm1_grad_gemm2_rhs_shape,
config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm1_lhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm1_lhs_shape,
config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor bmm2_grad_gemm2_rhs,
MatmulTensorDescriptorFor(
bmm2_grad_gemm2_rhs_shape,
config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS));
TF_ASSIGN_OR_RETURN(
MatmulTensorDescriptor d_output,
MatmulTensorDescriptorFor(
d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(),
RHS));

TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs,
TensorDescriptorFor(d_bmm1_lhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs,
TensorDescriptorFor(d_bmm1_rhs_shape));
TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs,
TensorDescriptorFor(d_bmm2_rhs_shape));

// The remaining outputs are the four amax scalars and the workspace.
TF_RET_CHECK(output_index ==
custom_call->shape().tuple_shapes().size() - 5);

TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));

TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
AsCudnnFmhaMaskKind(config.mask_type()));
TF_ASSIGN_OR_RETURN(
se::dnn::FMHAMaskKind dnn_mask_type,
GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type));
std::cout << "bf get cudnn graph\n";
TF_ASSIGN_OR_RETURN(
se::gpu::CudnnGraph graph,
se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph(
dnn_support, bmm1_grad_gemm1_rhs,
bmm1_grad_gemm2_rhs, bmm2_grad_gemm1_lhs,
bmm2_grad_gemm2_rhs, d_output,
d_bmm1_lhs, d_bmm1_rhs,
d_bmm2_rhs, config.fmha_scale(), dnn_mask_type));
std::cout << "end of wkspace rewrite\n";
return std::move(graph);
}
}
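Editorial note: reconstructed from the operand reads, the TF_RET_CHECKs, and the test HLO in this commit, the new branch assumes the operand and result layout sketched below for the "__cudnn$fmhaSoftmaxBackward$f8" custom call. The symbolic constant names are illustrative only and do not appear in the commit.

namespace fmha_f8_bwd_layout {
// Tensor operands read by the compiler pass; operand 5 (the intermediate
// tensor) is skipped and its shape comes from the backend config, and the
// remaining operands in the test HLO are f32 scale/descale scalars.
inline constexpr int kBmm1GradGemm1RhsOperand = 0;
inline constexpr int kBmm1GradGemm2RhsOperand = 1;
inline constexpr int kBmm2GradGemm2RhsOperand = 2;
inline constexpr int kFwdOutputOperand = 3;
inline constexpr int kDOutputOperand = 4;

// Result tuple: three gradient tensors first, then four amax scalars and the
// workspace, which is why the pass checks
// output_index == tuple_shapes().size() - 5.
inline constexpr int kDBmm1LhsResult = 0;
inline constexpr int kDBmm1RhsResult = 1;
inline constexpr int kDBmm2RhsResult = 2;
}  // namespace fmha_f8_bwd_layout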
