diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 5ef096475539e..ee7a1c12d29c5 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -369,7 +369,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( TensorDescriptorFor(d_bmm1_rhs_shape)); TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs, TensorDescriptorFor(d_bmm2_rhs_shape)); - // 4 gradients, 4 amaxs and one workspace + // 3 gradients, 4 amaxs and one workspace TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size()); TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config));