From 63e6f609edee5cd1282a69c69211a748851d573e Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Wed, 7 Aug 2024 16:12:09 -0700
Subject: [PATCH] Break up cutlass_extensions.cu, pt1 (#2944)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/47

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2944

- Break up cutlass_extensions.cu

Before - https://www.internalfb.com/buck2/86b77fae-04e1-44ca-b13f-d946417d6135: 8:14.5s

Differential Revision: D60874248
---
 fbgemm_gpu/experimental/gen_ai/CMakeLists.txt |   3 +
 .../gen_ai/src/quantize/cutlass_extensions.cu | 696 +-----------------
 .../quantize/cutlass_extensions/i8i8bf16.cu   | 322 ++++
 .../cutlass_extensions/i8i8bf16_dynamic.cu    | 195 +++++
 .../cutlass_extensions/include/kernel_mode.h  |  34 +
 .../cutlass_extensions/include/threadblock.h  | 302 ++++
 6 files changed, 858 insertions(+), 694 deletions(-)
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h

diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 01f1d6abd5..18796aaa47 100644
--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
+++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
@@ -15,6 +15,7 @@ set(fbgemm_sources_include_directories
   ${CMAKE_CURRENT_SOURCE_DIR}/../..
   ${CMAKE_CURRENT_SOURCE_DIR}/../../include
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
+  ${CMAKE_CURRENT_SOURCE_DIR}/src/quantize
   # PyTorch
   ${TORCH_INCLUDE_DIRS}
   # Third-party
@@ -31,6 +32,8 @@ set(attention_ops_sources

 set(quantize_ops_sources
     src/quantize/cutlass_extensions.cu
+    src/quantize/cutlass_extensions/i8i8bf16.cu
+    src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu
     src/quantize/quantize.cu
     src/quantize/quantize.cpp)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
index 9229146c29..dda8c0184f 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu
@@ -53,257 +53,10 @@
 #include
 #include
+#include "cutlass_extensions/include/kernel_mode.h"
+#include "cutlass_extensions/include/threadblock.h"
 #include "fp8_blockwise_cutlass_helpers.h"

-// Each block handles a single batch and head
-
-// Each warp handles separate D dimension.
-
-// Load Q into registers in all warps.
-// Split T across warps in a block
-// Compute S[MAX_T] = for i in range(T): S[t] = sum(Q[d] * K[t, d])
-// Use shared reduction to compute max and compute softmax on shared memory.
- -// Split T across warps in a block - -// each warp compute sum(t_subset) P[t] * V[t_subset, d] -// outputs are of size float[D] - -namespace cutlass::epilogue::threadblock::detail { - -/// Partial specialization for bfloat16 <= int32_t x 4 -template < - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename ThreadMap> -struct DefaultIteratorsTensorOp< - cutlass::bfloat16_t, - int32_t, - 8, - ThreadblockShape, - WarpShape, - InstructionShape, - ThreadMap> { - using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< - WarpShape, - InstructionShape, - int32_t, - layout::RowMajor>; - - using SharedLoadIterator = - cutlass::epilogue::threadblock::SharedLoadIterator; - - static int const kFragmentsPerIteration = 1; -}; - -} // namespace cutlass::epilogue::threadblock::detail - -// Wrapper to allow passing alpha/beta scaling params -// as device pointers. -namespace cutlass::epilogue::thread { - -template < - typename ElementOutput_, ///< Data type used to load and store tensors - int Count, ///< Number of elements computed per operation. - ///< Usually it is 128/sizeof_bits, - ///< but we use 64 or 32 sometimes when there are not enough data - ///< to store - typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type - typename ElementCompute_ = - ElementOutput_, ///< Data type used to compute linear combination - ScaleType::Kind Scale = - ScaleType::Default, ///< Control Alpha and Beta scaling - FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> -class LinearCombinationOnDevice { - public: - using ElementOutput = ElementOutput_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - - static int const kCount = Count; - static const ScaleType::Kind kScale = Scale; - using FragmentOutput = Array; - using FragmentAccumulator = Array; - using ComputeFragment = Array; - - using ParamsBase = LinearCombinationParams; - - static FloatRoundStyle const kRound = Round; - - /// Host-constructable parameters structure - struct Params : ParamsBase { - ElementCompute alpha; ///< scales accumulators - ElementCompute beta; ///< scales source tensor - ElementCompute const* alpha_ptr; ///< pointer to accumulator scalar - if not - ///< null, loads it from memory - ElementCompute const* beta_ptr; ///< pointer to source scalar - if not null, - ///< loads it from memory - - CUTLASS_HOST_DEVICE - Params() - : ParamsBase(ElementCompute(1), ElementCompute(0)), - alpha(ElementCompute(1)), - beta(ElementCompute(0)), - alpha_ptr(nullptr), - beta_ptr(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha, ElementCompute beta) - : ParamsBase(alpha, beta), - alpha(alpha), - beta(beta), - alpha_ptr(nullptr), - beta_ptr(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha) - : ParamsBase(alpha, ElementCompute(0)), - alpha(alpha), - beta(0), - alpha_ptr(nullptr), - beta_ptr(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute const* alpha_ptr, ElementCompute const* beta_ptr) - : ParamsBase(*alpha_ptr, *beta_ptr), - alpha(0), - beta(0), - alpha_ptr(alpha_ptr), - beta_ptr(beta_ptr) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute const* alpha_ptr) - : ParamsBase(ElementCompute(1), ElementCompute(0)), - alpha(0), - beta(0), - alpha_ptr(alpha_ptr), - beta_ptr(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(ParamsBase const& base) - : ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { -#if defined(__CUDA_ARCH__) - alpha = reinterpret_cast(base.alpha_data); - beta = 
reinterpret_cast(base.beta_data); -#else - memcpy(alpha, base.alpha_data, sizeof(ElementCompute)); - memcpy(beta, base.alpha_data, sizeof(ElementCompute)); -#endif - } - }; - - private: - // - // Data members - // - - const ElementCompute* alpha_ptr_; - ElementCompute beta_; - - public: - /// Constructs the function object, possibly loading from pointers in host - /// memory - CUTLASS_HOST_DEVICE - LinearCombinationOnDevice(Params const& params) { - alpha_ptr_ = params.alpha_ptr; - beta_ = ElementCompute(0); - // beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { - if (Scale == ScaleType::NoBetaScaling) - return true; - - if (Scale == ScaleType::OnlyAlphaScaling) - return false; - - if (Scale == ScaleType::Nothing) - return false; - - return beta_ != ElementCompute(0); - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - if (k_partition) { - beta_ = ElementCompute(1); - } - } - - /// Computes linear scaling: D = alpha * accumulator + beta * source - CUTLASS_HOST_DEVICE - FragmentOutput operator()( - FragmentAccumulator const& accumulator, - FragmentOutput const& source) const { - // Convert source to interal compute numeric type - NumericArrayConverter - source_converter; - NumericArrayConverter - accumulator_converter; - - // Convert to destination numeric type - NumericArrayConverter - destination_converter; - - ComputeFragment converted_source = source_converter(source); - ComputeFragment converted_accumulator = accumulator_converter(accumulator); - - if (Scale == ScaleType::Nothing) - return destination_converter(converted_accumulator); - - // Perform binary operations - ComputeFragment intermediate; - - multiplies mul_add_source; - multiply_add mul_add_accumulator; - - if (Scale == ScaleType::NoBetaScaling) - intermediate = converted_source; - else - intermediate = - mul_add_source(beta_, converted_source); // X = beta * C + uniform - - intermediate = mul_add_accumulator( - *alpha_ptr_, - converted_accumulator, - intermediate); // D = alpha * Accum + X - - return destination_converter(intermediate); - } - - /// Computes linear scaling: D = alpha * accumulator - CUTLASS_HOST_DEVICE - FragmentOutput operator()(FragmentAccumulator const& accumulator) const { - // Convert source to interal compute numeric type - NumericArrayConverter - accumulator_converter; - - // Convert to destination numeric type - NumericArrayConverter - destination_converter; - - ComputeFragment converted_accumulator = accumulator_converter(accumulator); - - if (Scale == ScaleType::Nothing) - return destination_converter(converted_accumulator); - - // Perform binary operations - ComputeFragment intermediate; - multiplies mul_accumulator; - - intermediate = mul_accumulator( - *alpha_ptr_, converted_accumulator); // D = alpha * Accum - - return destination_converter(intermediate); - } -}; - -} // namespace cutlass::epilogue::thread - namespace { int64_t ceil_div(int64_t a, int64_t b) { @@ -314,294 +67,7 @@ int64_t ceil_div(int64_t a, int64_t b) { namespace fbgemm_gpu { -template -at::Tensor i8i8bf16_impl( - at::Tensor XQ, // INT8 - at::Tensor WQ, // INT8 - double scale, - int64_t split_k) { - auto M = XQ.size(0); - auto N = WQ.size(0); - auto K = XQ.size(1); - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - - auto Y = at::empty({M, N}, 
XQ.options().dtype(at::kBFloat16)); - - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, - cutlass::layout::RowMajor, - int8_t, - cutlass::layout::ColumnMajor, - ElementOutput, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape, // ThreadBlockShape - cutlass::gemm::GemmShape, // WarpShape - cutlass::gemm::GemmShape<16, 8, 32>, // InstructionShape - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementComputeEpilogue>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - 16, - 16, - true>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - XQ.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - WQ.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - (ElementOutput*)Y.data_ptr(), - LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, - input_ref, - weight_ref, - out_ref, - out_ref, - {float(scale), 0.0}, - int(split_k)}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - auto workspace = - at::empty({int64_t(workspace_size)}, Y.options().dtype(at::kChar)); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize( - arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - -template -at::Tensor i8i8bf16sm90a_impl( - at::Tensor XQ, // INT8 - at::Tensor WQ, // INT8 - double scale) { - int M = XQ.size(0); - int N = WQ.size(0); - int K = XQ.size(1); - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - - auto Y = at::empty({M, N}, 
XQ.options().dtype(at::kBFloat16)); - - using ElementInputA = int8_t; - using LayoutInputA = cutlass::layout::RowMajor; - constexpr int AlignmentInputA = - 128 / - cutlass::sizeof_bits< - ElementInputA>::value; // Memory access granularity/alignment of A - // matrix in units of elements (up to 16 bytes) - - using ElementInputB = int8_t; - using LayoutInputB = cutlass::layout::ColumnMajor; - constexpr int AlignmentInputB = - 128 / - cutlass::sizeof_bits< - ElementInputB>::value; // Memory access granularity/alignment of B - // matrix in units of elements (up to 16 bytes) - - using ElementOutput = cutlass::bfloat16_t; - using LayoutOutput = cutlass::layout::ColumnMajor; - constexpr int AlignmentOutput = - 128 / - cutlass::sizeof_bits< - ElementOutput>::value; // Memory access granularity/alignment of C - // matrix in units of elements (up to 16 bytes) - - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Threadblock-level - // tile size - using ClusterShape = cute::Shape< - cute::Int, - cute::Int, - cute::Int>; // Shape of the - // threadblocks in a - // cluster - using StageCountType = - cutlass::gemm::collective::StageCountAuto; // Stage count maximized based - // on the tile size - using KernelSchedule = cutlass::gemm::collective:: - KernelScheduleAuto; // Kernel to launch based on the default setting in - // the Collective Builder - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementInputA, - LayoutInputA, - AlignmentInputA, - ElementInputB, - LayoutInputB, - AlignmentInputB, - ElementAccumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementComputeEpilogue, - ElementOutput, - LayoutOutput, - AlignmentOutput, - ElementOutput, - LayoutOutput, - AlignmentOutput, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideInputA = typename Gemm::GemmKernel::StrideA; - using StrideInputB = typename Gemm::GemmKernel::StrideB; - using StrideOutput = typename Gemm::GemmKernel::StrideC; - - StrideInputA stride_a = cutlass::make_cute_packed_stride( - StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{})); - StrideInputB stride_b = cutlass::make_cute_packed_stride( - StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{})); - StrideOutput stride_output = cutlass::make_cute_packed_stride( - StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{})); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {N, M, K}, - {WQ.data_ptr(), - stride_b, - XQ.data_ptr(), - stride_a}, - {{float(scale), 0}, - (ElementOutput*)Y.data_ptr(), - stride_output, - (ElementOutput*)Y.data_ptr(), - stride_output}}; - Gemm gemm; - - // Using the arguments, query for extra workspace required for matrix - // 
multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - #if CUDART_VERSION >= 12000 -enum class KernelMode { Small, Large, Default }; - -KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || - (K >= 2048 && N >= 2048)); - if (M <= 128 || N <= 128) { - return KernelMode::Small; - } else if (use_large_kernel) { - return KernelMode::Large; - } else { - return KernelMode::Default; - } -} // Cutlass tensorwise kernel template < @@ -2480,162 +1946,4 @@ at::Tensor f8f8bf16_blockwise( } #endif -at::Tensor i8i8bf16( - at::Tensor XQ, // INT8 - at::Tensor WQ, // INT8 - double scale, - int64_t split_k) { - auto M = XQ.size(0); - auto N = WQ.size(0); - auto K = XQ.size(1); -#ifdef SMOOTHQUANT_SM90A - if (M <= 128) { - return i8i8bf16sm90a_impl<64, 128, 128, 2, 1, 1>(XQ, WQ, scale); - } else { - return i8i8bf16sm90a_impl<128, 128, 128, 1, 2, 1>(XQ, WQ, scale); - } -#else - if (M <= 128 && N >= K) { - return i8i8bf16_impl<64, 128, 64, 32, 64, 64>(XQ, WQ, scale, split_k); - } else if (M <= 128 && N < K) { - return i8i8bf16_impl<64, 64, 128, 32, 32, 128>(XQ, WQ, scale, split_k); - } else { - return i8i8bf16_impl<256, 128, 64, 64, 64, 64>(XQ, WQ, scale, split_k); - } -#endif -} - -template -at::Tensor i8i8bf16_dynamic_impl( - at::Tensor XQ, // INT8 - at::Tensor WQ, // INT8 - at::Tensor scale, - int64_t split_k) { - auto M = XQ.size(0); - auto N = WQ.size(0); - auto K = XQ.size(1); - - TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); - TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); - - auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); - - using ElementOutput = cutlass::bfloat16_t; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. 
Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using Gemm = cutlass::gemm::device::Gemm< - int8_t, - cutlass::layout::RowMajor, - int8_t, - cutlass::layout::ColumnMajor, - ElementOutput, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape, // ThreadBlockShape - cutlass::gemm::GemmShape, // WarpShape - cutlass::gemm::GemmShape<16, 8, 32>, // InstructionShape - cutlass::epilogue::thread::LinearCombinationOnDevice< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementComputeEpilogue>, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - 3, - 16, - 16, - true>; - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - XQ.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - WQ.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - (ElementOutput*)Y.data_ptr(), - LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, - input_ref, - weight_ref, - out_ref, - out_ref, - {scale.data_ptr()}, - int(split_k)}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - auto workspace = - at::empty({int64_t(workspace_size)}, Y.options().dtype(at::kChar)); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize( - arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(at::cuda::getCurrentCUDAStream()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error( - std::string("cutlass cannot run") + - cutlass::cutlassGetStatusString(status)); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return Y; -} - -at::Tensor i8i8bf16_dynamic( - at::Tensor XQ, // INT8 - at::Tensor WQ, // INT8 - at::Tensor scale, - int64_t split_k) { - auto M = XQ.size(0); - auto N = WQ.size(0); - auto K = XQ.size(1); - if (M <= 128 && N >= K) { - return i8i8bf16_dynamic_impl<64, 128, 64, 32, 64, 64>( - XQ, WQ, scale, split_k); - } else if (M <= 128 && N < K) { - return i8i8bf16_dynamic_impl<64, 64, 128, 32, 32, 128>( - XQ, WQ, scale, split_k); - } else { - return i8i8bf16_dynamic_impl<256, 128, 64, 64, 64, 64>( - XQ, WQ, scale, split_k); - } -} - } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu new 
file mode 100644 index 0000000000..fea9cb9159 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu @@ -0,0 +1,322 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include +#include +#include +#include + +#include "cutlass_extensions/include/threadblock.h" + +namespace fbgemm_gpu { + +template +at::Tensor i8i8bf16_impl( + at::Tensor XQ, // INT8 + at::Tensor WQ, // INT8 + double scale, + int64_t split_k) { + auto M = XQ.size(0); + auto N = WQ.size(0); + auto K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape, // ThreadBlockShape + cutlass::gemm::GemmShape, // WarpShape + cutlass::gemm::GemmShape<16, 8, 32>, // InstructionShape + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementComputeEpilogue>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + true>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + XQ.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + WQ.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + (ElementOutput*)Y.data_ptr(), + LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, + input_ref, + weight_ref, + out_ref, + out_ref, + {float(scale), 0.0}, + int(split_k)}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + auto workspace = + at::empty({int64_t(workspace_size)}, Y.options().dtype(at::kChar)); + + // Check the problem size is supported or not 
+ cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize( + arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +template +at::Tensor i8i8bf16sm90a_impl( + at::Tensor XQ, // INT8 + at::Tensor WQ, // INT8 + double scale) { + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementInputA = int8_t; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = + 128 / + cutlass::sizeof_bits< + ElementInputA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + using ElementInputB = int8_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = + 128 / + cutlass::sizeof_bits< + ElementInputB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = + 128 / + cutlass::sizeof_bits< + ElementOutput>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, 
+ CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, cute::Int<1>{})); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, cute::Int<1>{})); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, cute::Int<1>{})); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K}, + {WQ.data_ptr(), + stride_b, + XQ.data_ptr(), + stride_a}, + {{float(scale), 0}, + (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +at::Tensor i8i8bf16( + at::Tensor XQ, // INT8 + at::Tensor WQ, // INT8 + double scale, + int64_t split_k) { + auto M = XQ.size(0); + auto N = WQ.size(0); + auto K = XQ.size(1); +#ifdef SMOOTHQUANT_SM90A + if (M <= 128) { + return i8i8bf16sm90a_impl<64, 128, 128, 2, 1, 1>(XQ, WQ, scale); + } else { + return i8i8bf16sm90a_impl<128, 128, 128, 1, 2, 1>(XQ, WQ, scale); + } +#else + if (M <= 128 && N >= K) { + return i8i8bf16_impl<64, 128, 64, 32, 64, 64>(XQ, WQ, scale, split_k); + } else if (M <= 128 && N < K) { + return i8i8bf16_impl<64, 64, 128, 32, 32, 128>(XQ, WQ, scale, split_k); + } else { + return i8i8bf16_impl<256, 128, 64, 64, 64, 64>(XQ, WQ, scale, split_k); + } +#endif +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu new file mode 100644 index 0000000000..f3553f5eb3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#if !( \ + defined(USE_ROCM) || \ + ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#include +#include +#include +#elif (defined(USE_ROCM)) +#include +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_utils.h" + +#if CUDART_VERSION >= 12000 +#include +#endif + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include +#include +#include +#include + +#include "cutlass_extensions/include/kernel_mode.h" +#include "cutlass_extensions/include/threadblock.h" +#include "fp8_blockwise_cutlass_helpers.h" + +namespace fbgemm_gpu { + +template +at::Tensor i8i8bf16_dynamic_impl( + at::Tensor XQ, // INT8 + at::Tensor WQ, // INT8 + at::Tensor scale, + int64_t split_k) { + auto M = XQ.size(0); + auto N = WQ.size(0); + auto K = XQ.size(1); + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16)); + + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape, // ThreadBlockShape + cutlass::gemm::GemmShape, // WarpShape + cutlass::gemm::GemmShape<16, 8, 32>, // InstructionShape + cutlass::epilogue::thread::LinearCombinationOnDevice< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementComputeEpilogue>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + true>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + XQ.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + WQ.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + (ElementOutput*)Y.data_ptr(), + LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, + input_ref, + weight_ref, + out_ref, + out_ref, + {scale.data_ptr()}, + int(split_k)}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = 
Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + auto workspace = + at::empty({int64_t(workspace_size)}, Y.options().dtype(at::kChar)); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize( + arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +at::Tensor i8i8bf16_dynamic( + at::Tensor XQ, // INT8 + at::Tensor WQ, // INT8 + at::Tensor scale, + int64_t split_k) { + auto M = XQ.size(0); + auto N = WQ.size(0); + auto K = XQ.size(1); + if (M <= 128 && N >= K) { + return i8i8bf16_dynamic_impl<64, 128, 64, 32, 64, 64>( + XQ, WQ, scale, split_k); + } else if (M <= 128 && N < K) { + return i8i8bf16_dynamic_impl<64, 64, 128, 32, 32, 128>( + XQ, WQ, scale, split_k); + } else { + return i8i8bf16_dynamic_impl<256, 128, 64, 64, 64, 64>( + XQ, WQ, scale, split_k); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h new file mode 100644 index 0000000000..94a68096d4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace fbgemm_gpu { + +enum class KernelMode { Small, Large, Default }; + +inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto M = XQ.size(0); + auto K = XQ.size(1); + auto N = WQ.size(0); + // Use a large kernel if at least two shapes are large.... + bool use_large_kernel = + ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + (K >= 2048 && N >= 2048)); + if (M <= 128 || N <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h new file mode 100644 index 0000000000..caf500174c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/threadblock.h @@ -0,0 +1,302 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#if !( \ + defined(USE_ROCM) || \ + ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#include +#include +#include +#elif (defined(USE_ROCM)) +#include +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_utils.h" + +#if CUDART_VERSION >= 12000 +#include +#endif + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include +#include +#include +#include + +#include "cutlass_extensions/include/kernel_mode.h" +#include "fp8_blockwise_cutlass_helpers.h" + +// Each block handles a single batch and head + +// Each warp handles separate D dimension. + +// Load Q into registers in all warps. +// Split T across warps in a block +// Compute S[MAX_T] = for i in range(T): S[t] = sum(Q[d] * K[t, d]) +// Use shared reduction to compute max and compute softmax on shared memory. + +// Split T across warps in a block + +// each warp compute sum(t_subset) P[t] * V[t_subset, d] +// outputs are of size float[D] + +namespace cutlass::epilogue::threadblock::detail { + +/// Partial specialization for bfloat16 <= int32_t x 4 +template < + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap> +struct DefaultIteratorsTensorOp< + cutlass::bfloat16_t, + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + int32_t, + layout::RowMajor>; + + using SharedLoadIterator = + cutlass::epilogue::threadblock::SharedLoadIterator; + + static int const kFragmentsPerIteration = 1; +}; + +} // namespace cutlass::epilogue::threadblock::detail + +// Wrapper to allow passing alpha/beta scaling params +// as device pointers. +namespace cutlass::epilogue::thread { + +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = + ElementOutput_, ///< Data type used to compute linear combination + ScaleType::Kind Scale = + ScaleType::Default, ///< Control Alpha and Beta scaling + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class LinearCombinationOnDevice { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + static const ScaleType::Kind kScale = Scale; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + using ParamsBase = LinearCombinationParams; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params : ParamsBase { + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const* alpha_ptr; ///< pointer to accumulator scalar - if not + ///< null, loads it from memory + ElementCompute const* beta_ptr; ///< pointer to source scalar - if not null, + ///< loads it from memory + + CUTLASS_HOST_DEVICE + Params() + : ParamsBase(ElementCompute(1), ElementCompute(0)), + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, ElementCompute beta) + : ParamsBase(alpha, beta), + alpha(alpha), + beta(beta), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha) + : ParamsBase(alpha, ElementCompute(0)), + alpha(alpha), + beta(0), + alpha_ptr(nullptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const* alpha_ptr, ElementCompute const* beta_ptr) + : ParamsBase(*alpha_ptr, *beta_ptr), + alpha(0), + beta(0), + alpha_ptr(alpha_ptr), + beta_ptr(beta_ptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const* alpha_ptr) + : ParamsBase(ElementCompute(1), ElementCompute(0)), + alpha(0), + beta(0), + alpha_ptr(alpha_ptr), + beta_ptr(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(ParamsBase const& base) + : ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { +#if defined(__CUDA_ARCH__) + alpha = reinterpret_cast(base.alpha_data); + beta = reinterpret_cast(base.beta_data); +#else + memcpy(alpha, base.alpha_data, sizeof(ElementCompute)); + memcpy(beta, base.alpha_data, sizeof(ElementCompute)); +#endif + } + }; + + private: + // + // Data members + // + + const ElementCompute* alpha_ptr_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + LinearCombinationOnDevice(Params const& params) { + alpha_ptr_ = params.alpha_ptr; + beta_ = ElementCompute(0); + // beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + if (Scale == ScaleType::NoBetaScaling) + return true; + + if (Scale == ScaleType::OnlyAlphaScaling) + return false; + + if (Scale == ScaleType::Nothing) + return false; + + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& accumulator, + FragmentOutput const& source) const { + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + if (Scale == ScaleType::Nothing) + return destination_converter(converted_accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + if (Scale == ScaleType::NoBetaScaling) + intermediate = converted_source; + else + intermediate = + mul_add_source(beta_, converted_source); // X = beta * C + uniform + + intermediate = mul_add_accumulator( + *alpha_ptr_, + converted_accumulator, + intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const& accumulator) const { + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + if (Scale == ScaleType::Nothing) + return destination_converter(converted_accumulator); + + // Perform binary operations + ComputeFragment intermediate; + multiplies mul_accumulator; + + intermediate = mul_accumulator( + *alpha_ptr_, converted_accumulator); // D = alpha * Accum + + return destination_converter(intermediate); + } +}; + +} // namespace cutlass::epilogue::thread
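
For orientation, below is a minimal host-side sketch (not part of the patch) showing how the relocated INT8 GEMM entry points are driven. The forward declarations mirror the signatures defined in i8i8bf16.cu and i8i8bf16_dynamic.cu above; the shapes, scale values, and the standalone main() are illustrative assumptions only.

// Hypothetical smoke test for the relocated INT8 -> BF16 GEMM entry points.
// Assumes a CUDA device is present and the program is linked against the
// fbgemm_gpu gen_ai quantize ops; shapes and scales are made up for the example.
#include <ATen/ATen.h>

namespace fbgemm_gpu {
// Declarations copied from the split-out .cu files above.
at::Tensor i8i8bf16(at::Tensor XQ, at::Tensor WQ, double scale, int64_t split_k);
at::Tensor i8i8bf16_dynamic(
    at::Tensor XQ, at::Tensor WQ, at::Tensor scale, int64_t split_k);
} // namespace fbgemm_gpu

int main() {
  auto opts = at::TensorOptions().dtype(at::kChar).device(at::kCUDA);
  // XQ is M x K, WQ is N x K (both row-major, contiguous INT8), as the ops expect.
  auto XQ = at::randint(-128, 127, {64, 512}, opts);
  auto WQ = at::randint(-128, 127, {256, 512}, opts);

  // Tensor-wise host scale applied in the epilogue; split_k = 1.
  auto Y = fbgemm_gpu::i8i8bf16(XQ, WQ, /*scale=*/0.05, /*split_k=*/1);

  // The dynamic variant reads the scale from a device tensor inside the
  // epilogue (LinearCombinationOnDevice), so no host sync is needed.
  auto scale_t = at::full({1}, 0.05, opts.dtype(at::kFloat));
  auto Yd = fbgemm_gpu::i8i8bf16_dynamic(XQ, WQ, scale_t, /*split_k=*/1);

  // Both outputs are M x N BF16 tensors.
  return (Y.size(0) == 64 && Yd.size(1) == 256) ? 0 : 1;
}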