From 18ff9bad70d6c4f42beed2e253ff8196c5de6e8b Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 8 Aug 2024 18:09:49 -0700 Subject: [PATCH] Break up cutlass_extensions.cu, pt 2 Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/57 - Break up cutlass_extensions.cu, pt 2 Differential Revision: D60942612 --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 3 + .../gen_ai/src/quantize/cutlass_extensions.cu | 874 +----------------- .../cutlass_extensions/f8f8bf16_blockwise.cu | 288 ++++++ .../cutlass_extensions/f8f8bf16_cublas.cu | 180 ++++ .../cutlass_extensions/f8f8bf16_rowwise.cu | 494 ++++++++++ .../quantize/cutlass_extensions/i8i8bf16.cu | 1 + .../cutlass_extensions/include/threadblock.h | 36 +- 7 files changed, 976 insertions(+), 900 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_cublas.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index 18796aaa47..bd22dc0030 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -32,6 +32,9 @@ set(attention_ops_sources set(quantize_ops_sources src/quantize/cutlass_extensions.cu + src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu + src/quantize/cutlass_extensions/f8f8bf16_cublas.cu + src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu src/quantize/cutlass_extensions/i8i8bf16.cu src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu src/quantize/quantize.cu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu index dda8c0184f..5ab681f019 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu @@ -57,14 +57,6 @@ #include "cutlass_extensions/include/threadblock.h" #include "fp8_blockwise_cutlass_helpers.h" -namespace { - -int64_t ceil_div(int64_t a, int64_t b) { - return (a + b - 1) / b; -} - -} // namespace - namespace fbgemm_gpu { #if CUDART_VERSION >= 12000 @@ -486,691 +478,6 @@ at::Tensor f8f8bf16_tensorwise( } } -// Cutlass rowwise kernel -template < - int TB_M, - int TB_N, - int TB_K, - int TBS_M, - int TBS_N, - int TBS_K, - bool PONG, - bool FAST_ACCUM, - bool USE_BIAS, - typename INPUT_DTYPE, - typename BIAS_DTYPE> -at::Tensor f8f8bf16_rowwise_impl( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - std::optional bias, - std::optional output) { - // XQ: M x K - // WQ: N x K - // output: M x N - int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); - int N = WQ.size(0); - int K = WQ.size(1); - TORCH_CHECK(XQ.size(-1) == K); - // 1. If the input tensor is {M, K}, the output tensor is {M, N}. - // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. - auto out_sizes = XQ.sizes().vec(); - out_sizes.back() = N; - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - - at::Tensor Y; - if (output.has_value()) { - Y = output.value(); - // Make sure the provided output has the proper shape and dtype. 
- TORCH_CHECK(Y.sizes().vec() == out_sizes); - TORCH_CHECK(Y.dtype() == at::kBFloat16); - } else { - Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); - } - - using ElementInputA = INPUT_DTYPE; - using LayoutInputA = cutlass::layout::RowMajor; - constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); - - using ElementInputB = cutlass::float_e4m3_t; - using LayoutInputB = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); - - using ElementBias = BIAS_DTYPE; - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::RowMajor; - constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); - - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - using StageCountType = - cutlass::gemm::collective::StageCountAuto; // Stage count maximized - // based on the tile size - using KernelSchedule = cutlass::gemm::collective:: - KernelScheduleAuto; // Kernel to launch based on the default setting in - // the Collective Builder - - // Implement rowwise scaling epilogue. - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; - - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - PONG ? 2 : 1, - TileShape, - ElementBias, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - ElementComputeEpilogue, // First stage output type. - ElementComputeEpilogue, // First stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, - cute::conditional_t< // Second stage output type. - USE_BIAS, - ElementBias, - ElementOutput>, - ElementComputeEpilogue, // Second stage input types. - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute1 = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< - cutlass::plus, - ElementOutput, // Final (optional) stage output type. - ElementBias, // Final stage input types. 
- cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeBias = - cutlass::epilogue::fusion::Sm90EVT; - - using EpilogueEVT = - cute::conditional_t; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementComputeEpilogue, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - cutlass::epilogue::TmaWarpSpecialized, - EpilogueEVT>::CollectiveOp; - - using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; - using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using FastDefaultSchedule = - cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; - using FastPongSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using SlowAccum = cute::conditional_t; - using FastAccum = - cute::conditional_t; - using MainLoopSchedule = - cute::conditional_t; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementInputB, - LayoutInputB, - AlignmentInputB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideC; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{})); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{})); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(M, N, cute::Int<1>{})); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K}, - {reinterpret_cast(XQ.data_ptr()), - stride_a, - reinterpret_cast(WQ.data_ptr()), - stride_b}, - {{}, // Epilogue thread we populate below. 
- (ElementOutput*)Y.data_ptr(), - stride_output, - (ElementOutput*)Y.data_ptr(), - stride_output}}; - - if constexpr (USE_BIAS) { - arguments.epilogue.thread = { - {reinterpret_cast(bias.value().data_ptr())}, // bias - // compute_1 - { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }, - {}, // Plus - }; - } else { - arguments.epilogue.thread = { - {reinterpret_cast( - x_scale.data_ptr())}, // x_scale - // compute_0 - { - {reinterpret_cast( - w_scale.data_ptr())}, // w_scale - {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }; - } - - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - -// FP8 Rowwise Cutlass kernel dispatch. -template -at::Tensor dispatch_fp8_rowwise_kernel( - at::Tensor XQ, - at::Tensor WQ, - at::Tensor x_scale, - at::Tensor w_scale, - std::optional bias, - std::optional output) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { - return f8f8bf16_rowwise_impl< - 64, - 128, - 128, - 2, - 1, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); - } else if (kernel == KernelMode::Large) { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 2, - 1, - 1, - true, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return f8f8bf16_rowwise_impl< - 128, - 128, - 128, - 1, - 2, - 1, - false, - FastAccum, - UseBias, - InputDType, - BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); - } -} - -at::Tensor f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - std::optional bias = c10::nullopt, - bool use_fast_accum = true, - std::optional output = c10::nullopt) { - // Check datatypes. - TORCH_CHECK( - x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, - "Scale tensors must be float32."); - if (bias.has_value()) { - TORCH_CHECK( - bias.value().dtype() == at::kFloat || - bias.value().dtype() == at::kBFloat16, - "Bias type must be bfloat16 or float32 if provided."); - } - bool use_bias = bias.has_value(); - bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; - - // Templatize based on input dtype. 
- bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; - - if (use_bias) { - if (bf16_bias) { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - true, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - true, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } - } - } - } else { - if (use_fast_accum) { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - true, - false, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } - } else { - if (use_e5m2) { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e5m2_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } else { - return dispatch_fp8_rowwise_kernel< - cutlass::float_e4m3_t, - false, - false, - float>(XQ, WQ, x_scale, w_scale, bias, output); - } - } - } -} - -// Cutlass blockwise kernel -template < - int TB_M, - int TB_N, - int TB_K, - int TBS_M, - int TBS_N, - int TBS_K> -at::Tensor f8f8bf16_blockwise_impl( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - int64_t block_m, - int64_t block_n, - int64_t block_k) { - // XQ: M x K - // WQ: N x K - // output: M x N - int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); - int N = WQ.size(0); - int K = WQ.size(1); - // 1. If the input tensor is {M, K}, the output tensor is {M, N}. - // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. 
- auto out_sizes = XQ.sizes().vec(); - out_sizes.back() = N; - - TORCH_CHECK(WQ.size(1) == K); - TORCH_CHECK(XQ.stride(-1) == 1); - TORCH_CHECK(WQ.stride(0) == K); - TORCH_CHECK(WQ.stride(1) == 1); - - TORCH_CHECK(block_m % TB_N == 0); - TORCH_CHECK(block_n % TB_M == 0); - TORCH_CHECK(block_k % TB_K == 0); - - TORCH_CHECK(x_scale.dim() == 2); - TORCH_CHECK(w_scale.dim() == 2); - TORCH_CHECK(x_scale.size(0) == ceil_div(M, block_m)); - TORCH_CHECK(x_scale.size(1) == ceil_div(K, block_k)); - TORCH_CHECK(w_scale.size(0) == ceil_div(N, block_n)); - TORCH_CHECK(w_scale.size(1) == ceil_div(K, block_k)); - TORCH_CHECK(x_scale.stride(0) == ceil_div(K, block_k)); - TORCH_CHECK(x_scale.stride(1) == 1); - TORCH_CHECK(w_scale.stride(0) == ceil_div(K, block_k)); - TORCH_CHECK(w_scale.stride(1) == 1); - - TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn); - TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn); - TORCH_CHECK(XQ.is_cuda()); - TORCH_CHECK(WQ.is_cuda()); - TORCH_CHECK(XQ.device().index() == WQ.device().index()); - TORCH_CHECK(x_scale.dtype() == at::kFloat); - TORCH_CHECK(w_scale.dtype() == at::kFloat); - TORCH_CHECK(x_scale.is_cuda()); - TORCH_CHECK(w_scale.is_cuda()); - TORCH_CHECK(x_scale.device().index() == XQ.device().index()); - TORCH_CHECK(w_scale.device().index() == XQ.device().index()); - - auto Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); - - using ElementInputA = cutlass::float_e4m3_t; - using LayoutInputA = cutlass::layout::RowMajor; - constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); - - using ElementInputB = cutlass::float_e4m3_t; - using LayoutInputB = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::ColumnMajor; - constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); - - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementComputeEpilogue, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp; - - using MainLoopSchedule = - cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaling; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementInputB, - LayoutInputB, - AlignmentInputB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename 
Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideD; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(N, K, cute::Int<1>{})); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(M, K, cute::Int<1>{})); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{})); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {N, M, K}, - {reinterpret_cast(WQ.data_ptr()), - stride_a, - reinterpret_cast(XQ.data_ptr()), - stride_b, - w_scale.data_ptr(), - x_scale.data_ptr(), - static_cast(block_n / TB_M), - static_cast(block_m / TB_N), - static_cast(block_k / TB_K)}, - {{}, - (cutlass::bfloat16_t*)Y.data_ptr(), - stride_output, - (cutlass::bfloat16_t*)Y.data_ptr(), - stride_output}, - }; - - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - -// FP8 blockwise Cutlass kernel dispatch. -at::Tensor dispatch_fp8_blockwise_kernel( - at::Tensor XQ, - at::Tensor WQ, - at::Tensor x_scale, - at::Tensor w_scale, - int64_t block_m, - int64_t block_n, - int64_t block_k) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { - return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>( - XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); - } else if (kernel == KernelMode::Large) { - return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>( - XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); - } else { - return f8f8bf16_blockwise_impl<128, 128, 128, 1, 2, 1>( - XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); - } -} - -at::Tensor f8f8bf16_blockwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, // FP32 - at::Tensor w_scale, // FP32 - int64_t block_m = 256, - int64_t block_n = 256, - int64_t block_k = 256) { - // Check datatypes. - TORCH_CHECK( - x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, - "Scale tensors must be float32."); - - return dispatch_fp8_blockwise_kernel( - XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); -} - template < int TB_M, int TB_N, @@ -1734,161 +1041,8 @@ at::Tensor f8i4bf16_rowwise( } } -at::Tensor f8f8bf16_cublas( - at::Tensor A, // FP8 - at::Tensor B, // FP8 - std::optional Ainvs = c10::nullopt, - std::optional Binvs = c10::nullopt, - bool use_fast_accum = true, - std::optional output = c10::nullopt) { - auto m = A.size(0); - auto n = B.size(0); - auto k = A.size(1); - size_t workspaceSize = CUBLAS_WORKSPACE_SIZE; - const int8_t fastAccuMode = use_fast_accum ? 
1 : 0; - - TORCH_CHECK(A.is_cuda() && A.is_contiguous()); - TORCH_CHECK(B.is_cuda() && B.is_contiguous()); - - cublasLtHandle_t ltHandle; - checkCublasStatus(cublasLtCreate(<Handle)); - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto workspace = allocator.allocate(workspaceSize); - if (output.has_value()) { - auto output_tensor = output.value(); - TORCH_CHECK(output_tensor.is_cuda()); - TORCH_CHECK(output_tensor.is_contiguous()); - TORCH_CHECK( - output_tensor.numel() == m * n, - "output_tensor.numel=", - output_tensor.numel(), - ", m=", - m, - ", n=", - n); - TORCH_CHECK(output_tensor.options().dtype() == at::kBFloat16); - } - - const cudaDataType_t A_type = CUDA_R_8F_E4M3; - const cudaDataType_t B_type = CUDA_R_8F_E4M3; - const cudaDataType_t D_type = CUDA_R_16BF; - - float one = 1.0; - float zero = 0.0; - - cublasOperation_t transa = CUBLAS_OP_T; - cublasOperation_t transb = CUBLAS_OP_N; - - cublasLtMatmulDesc_t operationDesc = nullptr; - cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Ddesc = nullptr; - cublasLtMatmulPreference_t preference = nullptr; - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; - // Create matrix descriptors. Not setting any extra attributes. - - auto lda = k; - auto ldb = k; - auto ldd = n; - checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, A_type, k, m, lda)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, B_type, k, n, ldb)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&Ddesc, D_type, n, m, ldd)); - - checkCublasStatus( - cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); - - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_FAST_ACCUM, - &fastAccuMode, - sizeof(fastAccuMode))); - - if (Ainvs.has_value()) { - const float* Ainvs_pt = Ainvs.value().data_ptr(); - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - &Ainvs_pt, - sizeof(Ainvs_pt))); - } - - if (Binvs.has_value()) { - const float* Binvs_pt = Binvs.value().data_ptr(); - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &Binvs_pt, - sizeof(Binvs_pt))); - } - - checkCublasStatus(cublasLtMatmulDescSetAttribute( - operationDesc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue))); - - checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference)); - - checkCublasStatus(cublasLtMatmulPreferenceSetAttribute( - preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspaceSize, - sizeof(workspaceSize))); - - checkCublasStatus(cublasLtMatmulAlgoGetHeuristic( - ltHandle, - operationDesc, - Bdesc, - Adesc, - Ddesc, - Ddesc, - preference, - 1, - &heuristicResult, - &returnedResults)); - - if (returnedResults == 0) - throw std::runtime_error("Unable to find any suitable algorithms"); - - // D = alpha * (A * B) + beta * C - // Warmup - auto Y = output.value_or(at::empty({m, n}, A.options().dtype(at::kBFloat16))); - checkCublasStatus(cublasLtMatmul( - ltHandle, - operationDesc, - static_cast(&one), /* alpha */ - B.data_ptr(), /* B */ - Bdesc, - A.data_ptr(), /* A */ - Adesc, - 
static_cast(&zero), /* beta */ - nullptr, /* C */ - Ddesc, - Y.data_ptr(), /* D */ - Ddesc, - &heuristicResult.algo, /* algo */ - workspace.mutable_get(), /* workspace */ - workspaceSize, - at::cuda::getCurrentCUDAStream())); /* stream */ - return Y; -} #else -at::Tensor f8f8bf16_cublas( - at::Tensor A, // FP8 - at::Tensor B, // FP8 - std::optional Ainvs = c10::nullopt, - std::optional Binvs = c10::nullopt, - bool use_fast_accum = true, - std::optional output = c10::nullopt) { - throw std::runtime_error( - "CUDA version is older than 12.0"); // requires CUDA>=12 -} + at::Tensor f8f8bf16( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -1897,6 +1051,7 @@ at::Tensor f8f8bf16( throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } + at::Tensor f8f8bf16_tensorwise( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -1905,6 +1060,7 @@ at::Tensor f8f8bf16_tensorwise( throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } + at::Tensor f8i4bf16_rowwise( at::Tensor XQ, // FP8 at::Tensor WQ, // INT4 @@ -1914,6 +1070,7 @@ at::Tensor f8i4bf16_rowwise( throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } + at::Tensor bf16i4bf16_rowwise( at::Tensor X, // BF16 at::Tensor WQ, // INT4 @@ -1922,28 +1079,7 @@ at::Tensor bf16i4bf16_rowwise( throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } -at::Tensor f8f8bf16_rowwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - std::optional bias = c10::nullopt, - bool use_fast_accum = true, - std::optional output = c10::nullopt) { - throw std::runtime_error( - "CUDA version is older than 12.0"); // requires CUDA>=12 -} -at::Tensor f8f8bf16_blockwise( - at::Tensor XQ, // FP8 - at::Tensor WQ, // FP8 - at::Tensor x_scale, - at::Tensor w_scale, - int64_t block_m = 256, - int64_t block_n = 256, - int64_t block_k = 256) { - throw std::runtime_error( - "CUDA version is older than 12.0"); // requires CUDA>=12 -} + #endif } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu new file mode 100644 index 0000000000..3cfccbe175 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu @@ -0,0 +1,288 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" +#include "fp8_blockwise_cutlass_helpers.h" + +namespace { + +int64_t ceil_div(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +} // namespace + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +// Cutlass blockwise kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K> +at::Tensor f8f8bf16_blockwise_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + int64_t block_m, + int64_t block_n, + int64_t block_k) { + // XQ: M x K + // WQ: N x K + // output: M x N + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + // 1. If the input tensor is {M, K}, the output tensor is {M, N}. + // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + + TORCH_CHECK(WQ.size(1) == K); + TORCH_CHECK(XQ.stride(-1) == 1); + TORCH_CHECK(WQ.stride(0) == K); + TORCH_CHECK(WQ.stride(1) == 1); + + TORCH_CHECK(block_m % TB_N == 0); + TORCH_CHECK(block_n % TB_M == 0); + TORCH_CHECK(block_k % TB_K == 0); + + TORCH_CHECK(x_scale.dim() == 2); + TORCH_CHECK(w_scale.dim() == 2); + TORCH_CHECK(x_scale.size(0) == ceil_div(M, block_m)); + TORCH_CHECK(x_scale.size(1) == ceil_div(K, block_k)); + TORCH_CHECK(w_scale.size(0) == ceil_div(N, block_n)); + TORCH_CHECK(w_scale.size(1) == ceil_div(K, block_k)); + TORCH_CHECK(x_scale.stride(0) == ceil_div(K, block_k)); + TORCH_CHECK(x_scale.stride(1) == 1); + TORCH_CHECK(w_scale.stride(0) == ceil_div(K, block_k)); + TORCH_CHECK(w_scale.stride(1) == 1); + + TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn); + TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn); + TORCH_CHECK(XQ.is_cuda()); + TORCH_CHECK(WQ.is_cuda()); + TORCH_CHECK(XQ.device().index() == WQ.device().index()); + TORCH_CHECK(x_scale.dtype() == at::kFloat); + TORCH_CHECK(w_scale.dtype() == at::kFloat); + TORCH_CHECK(x_scale.is_cuda()); + TORCH_CHECK(w_scale.is_cuda()); + TORCH_CHECK(x_scale.device().index() == XQ.device().index()); + TORCH_CHECK(w_scale.device().index() == XQ.device().index()); + + auto Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = cutlass::float_e4m3_t; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + 
TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp; + + using MainLoopSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaling; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideD; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(N, K, cute::Int<1>{})); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(M, K, cute::Int<1>{})); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{})); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K}, + {reinterpret_cast(WQ.data_ptr()), + stride_a, + reinterpret_cast(XQ.data_ptr()), + stride_b, + w_scale.data_ptr(), + x_scale.data_ptr(), + static_cast(block_n / TB_M), + static_cast(block_m / TB_N), + static_cast(block_k / TB_K)}, + {{}, + (cutlass::bfloat16_t*)Y.data_ptr(), + stride_output, + (cutlass::bfloat16_t*)Y.data_ptr(), + stride_output}, + }; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +// FP8 blockwise Cutlass kernel dispatch. 
+at::Tensor dispatch_fp8_blockwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + int64_t block_m, + int64_t block_n, + int64_t block_k) { + KernelMode kernel = get_kernel_mode(XQ, WQ); + if (kernel == KernelMode::Small) { + return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>( + XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>( + XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); + } else { + return f8f8bf16_blockwise_impl<128, 128, 128, 1, 2, 1>( + XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); + } +} + +at::Tensor f8f8bf16_blockwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + int64_t block_m = 256, + int64_t block_n = 256, + int64_t block_k = 256) { + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + + return dispatch_fp8_blockwise_kernel( + XQ, WQ, x_scale, w_scale, block_m, block_n, block_k); +} + +#else + +at::Tensor f8f8bf16_blockwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + int64_t block_m = 256, + int64_t block_n = 256, + int64_t block_k = 256) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_cublas.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_cublas.cu new file mode 100644 index 0000000000..04979323e9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_cublas.cu @@ -0,0 +1,180 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "cublas_utils.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +at::Tensor f8f8bf16_cublas( + at::Tensor A, // FP8 + at::Tensor B, // FP8 + std::optional Ainvs = c10::nullopt, + std::optional Binvs = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + auto m = A.size(0); + auto n = B.size(0); + auto k = A.size(1); + size_t workspaceSize = CUBLAS_WORKSPACE_SIZE; + const int8_t fastAccuMode = use_fast_accum ? 
1 : 0; + + TORCH_CHECK(A.is_cuda() && A.is_contiguous()); + TORCH_CHECK(B.is_cuda() && B.is_contiguous()); + + cublasLtHandle_t ltHandle; + checkCublasStatus(cublasLtCreate(<Handle)); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto workspace = allocator.allocate(workspaceSize); + if (output.has_value()) { + auto output_tensor = output.value(); + TORCH_CHECK(output_tensor.is_cuda()); + TORCH_CHECK(output_tensor.is_contiguous()); + TORCH_CHECK( + output_tensor.numel() == m * n, + "output_tensor.numel=", + output_tensor.numel(), + ", m=", + m, + ", n=", + n); + TORCH_CHECK(output_tensor.options().dtype() == at::kBFloat16); + } + + const cudaDataType_t A_type = CUDA_R_8F_E4M3; + const cudaDataType_t B_type = CUDA_R_8F_E4M3; + const cudaDataType_t D_type = CUDA_R_16BF; + + float one = 1.0; + float zero = 0.0; + + cublasOperation_t transa = CUBLAS_OP_T; + cublasOperation_t transb = CUBLAS_OP_N; + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Ddesc = nullptr; + cublasLtMatmulPreference_t preference = nullptr; + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; + // Create matrix descriptors. Not setting any extra attributes. + + auto lda = k; + auto ldb = k; + auto ldd = n; + checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, A_type, k, m, lda)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, B_type, k, n, ldb)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&Ddesc, D_type, n, m, ldd)); + + checkCublasStatus( + cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, + sizeof(fastAccuMode))); + + if (Ainvs.has_value()) { + const float* Ainvs_pt = Ainvs.value().data_ptr(); + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &Ainvs_pt, + sizeof(Ainvs_pt))); + } + + if (Binvs.has_value()) { + const float* Binvs_pt = Binvs.value().data_ptr(); + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &Binvs_pt, + sizeof(Binvs_pt))); + } + + checkCublasStatus(cublasLtMatmulDescSetAttribute( + operationDesc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue))); + + checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference)); + + checkCublasStatus(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, + sizeof(workspaceSize))); + + checkCublasStatus(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + operationDesc, + Bdesc, + Adesc, + Ddesc, + Ddesc, + preference, + 1, + &heuristicResult, + &returnedResults)); + + if (returnedResults == 0) + throw std::runtime_error("Unable to find any suitable algorithms"); + + // D = alpha * (A * B) + beta * C + // Warmup + auto Y = output.value_or(at::empty({m, n}, A.options().dtype(at::kBFloat16))); + checkCublasStatus(cublasLtMatmul( + ltHandle, + operationDesc, + static_cast(&one), /* alpha */ + B.data_ptr(), /* B */ + Bdesc, + A.data_ptr(), /* A */ + Adesc, + 
static_cast(&zero), /* beta */ + nullptr, /* C */ + Ddesc, + Y.data_ptr(), /* D */ + Ddesc, + &heuristicResult.algo, /* algo */ + workspace.mutable_get(), /* workspace */ + workspaceSize, + at::cuda::getCurrentCUDAStream())); /* stream */ + return Y; +} + +#else + +at::Tensor f8f8bf16_cublas( + at::Tensor A, // FP8 + at::Tensor B, // FP8 + std::optional Ainvs = c10::nullopt, + std::optional Binvs = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu new file mode 100644 index 0000000000..1e27eef6e7 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu @@ -0,0 +1,494 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +// Cutlass rowwise kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +at::Tensor f8f8bf16_rowwise_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + // XQ: M x K + // WQ: N x K + // output: M x N + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + TORCH_CHECK(XQ.size(-1) == K); + // 1. If the input tensor is {M, K}, the output tensor is {M, N}. + // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}. + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + at::Tensor Y; + if (output.has_value()) { + Y = output.value(); + // Make sure the provided output has the proper shape and dtype. 
+ TORCH_CHECK(Y.sizes().vec() == out_sizes); + TORCH_CHECK(Y.dtype() == at::kBFloat16); + } else { + Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. 
+ cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{})); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{})); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, cute::Int<1>{})); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. 
+ (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr())}, // bias + // compute_1 + { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast( + x_scale.data_ptr())}, // x_scale + // compute_0 + { + {reinterpret_cast( + w_scale.data_ptr())}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +// FP8 Rowwise Cutlass kernel dispatch. +template +at::Tensor dispatch_fp8_rowwise_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + KernelMode kernel = get_kernel_mode(XQ, WQ); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return f8f8bf16_rowwise_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } +} + +at::Tensor f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + // Check datatypes. + TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. 
+ bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } +} + +#else + +at::Tensor f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu index fea9cb9159..5fa1dec756 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu @@ -7,6 +7,7 @@ */ #include +#include // clang-format off // The fixed ordering of the headers is required for CUTLASS 3.2+ diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h index caf500174c..2651676861 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h @@ -8,33 +8,17 @@ #pragma once -#if !( \ - defined(USE_ROCM) || \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) -#include -#include -#include -#include -#elif (defined(USE_ROCM)) 
-#include -#include -#include -#endif -#include -#include +#include #include #include #include +#include +#include #include #include #include #include -#include "cublas_utils.h" - -#if CUDART_VERSION >= 12000 -#include -#endif +#include // clang-format off // The fixed ordering of the headers is required for CUTLASS 3.2+ @@ -44,16 +28,7 @@ #include // @manual // clang-format on -#include -#include -#include -#include - -#include "cutlass_extensions/include/kernel_mode.h" -#include "fp8_blockwise_cutlass_helpers.h" - // Each block handles a single batch and head - // Each warp handles separate D dimension. // Load Q into registers in all warps. @@ -62,8 +37,7 @@ // Use shared reduction to compute max and compute softmax on shared memory. // Split T across warps in a block - -// each warp compute sum(t_subset) P[t] * V[t_subset, d] +// Each warp compute sum(t_subset) P[t] * V[t_subset, d] // outputs are of size float[D] namespace cutlass::epilogue::threadblock::detail {
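
For reference, the relocated f8f8bf16_blockwise kernel expresses its scale-layout contract only through TORCH_CHECKs, so the following is a minimal caller-side sketch of that contract, not part of the patch. It re-declares the operator signature as it appears above instead of including a real public header (the exact header is not shown in this diff), the run_blockwise_example and ceil_div names are illustrative, and it assumes ATen can cast float tensors to at::kFloat8_e4m3fn via Tensor::to. The point it demonstrates is the shape arithmetic the checks enforce: x_scale is [ceil_div(M, block_m), ceil_div(K, block_k)] and w_scale is [ceil_div(N, block_n), ceil_div(K, block_k)], both contiguous float32 tensors on the same CUDA device as XQ and WQ, with XQ of shape M x K and WQ of shape N x K.

// Caller-side sketch (illustrative only): shapes required by the
// f8f8bf16_blockwise TORCH_CHECKs in the patch above.
#include <ATen/ATen.h>
#include <cstdint>

namespace fbgemm_gpu {
// Signature mirrored from the patch; the real declaration lives in the library.
at::Tensor f8f8bf16_blockwise(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    int64_t block_m,
    int64_t block_n,
    int64_t block_k);
} // namespace fbgemm_gpu

namespace {
// Same helper the kernel uses to derive the per-block scale grid.
int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
} // namespace

at::Tensor run_blockwise_example(int64_t M, int64_t N, int64_t K) {
  const int64_t block_m = 256, block_n = 256, block_k = 256; // kernel defaults

  auto opts = at::TensorOptions().device(at::kCUDA);

  // XQ: M x K and WQ: N x K, both FP8 e4m3 and contiguous (row-major),
  // as required by the dtype/stride checks in f8f8bf16_blockwise_impl.
  at::Tensor XQ =
      at::randn({M, K}, opts.dtype(at::kFloat)).to(at::kFloat8_e4m3fn);
  at::Tensor WQ =
      at::randn({N, K}, opts.dtype(at::kFloat)).to(at::kFloat8_e4m3fn);

  // Scales are 2D float32 with one value per (row block, K block):
  //   x_scale: ceil_div(M, block_m) x ceil_div(K, block_k)
  //   w_scale: ceil_div(N, block_n) x ceil_div(K, block_k)
  at::Tensor x_scale = at::ones(
      {ceil_div(M, block_m), ceil_div(K, block_k)}, opts.dtype(at::kFloat));
  at::Tensor w_scale = at::ones(
      {ceil_div(N, block_n), ceil_div(K, block_k)}, opts.dtype(at::kFloat));

  // Returns an M x N bfloat16 tensor.
  return fbgemm_gpu::f8f8bf16_blockwise(
      XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
}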