From 7105d5f70f47f5edd2f7805ad835d5a29e6db2bd Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Thu, 15 Aug 2024 05:41:42 -0700 Subject: [PATCH] Add kv cache related ops (#2968) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2968 X-link: https://github.com/facebookresearch/FBGEMM/pull/65 Reviewed By: Aya-ZIbra Differential Revision: D60951847 fbshipit-source-id: 9d3fa828fcdb3c905c8cf8ca6fbce2066252724b --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 7 +- .../gen_ai/src/attention/gqa_attn_splitk.cu | 453 +---- .../experimental/gen_ai/src/comm/car.cu | 73 +- .../gen_ai/src/kv_cache/kv_cache.cpp | 271 +++ .../gen_ai/src/kv_cache/kv_cache.cu | 1527 +++++++++++++++++ .../gen_ai/src/quantize/quantize.cu | 125 +- .../gen_ai/test/kv_cache/kv_cache_test.py | 578 +++++++ .../gen_ai/test/kv_cache/rope_padded.py | 331 ++++ .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 4 + .../include/fbgemm_gpu/utils/vec_quant.cuh | 431 +++++ 10 files changed, 3182 insertions(+), 618 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp create mode 100644 fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py create mode 100644 fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py create mode 100644 fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index 0ffbad7182..7716c39c17 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -56,10 +56,15 @@ set(comm_ops_sources src/comm/car.cu src/comm/car.cpp) +set(kv_cache_ops_sources + src/kv_cache/kv_cache.cu + src/kv_cache/kv_cache.cpp) + set(experimental_gen_ai_cpp_source_files ${attention_ops_sources} ${quantize_ops_sources} - ${comm_ops_sources}) + ${comm_ops_sources} + ${kv_cache_ops_sources}) # Set the source file for FB only CPP if(USE_FB_ONLY) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu index 4776d96f99..59d4646c89 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu @@ -13,21 +13,6 @@ /// @defgroup experimental-gen-ai-attention /// This is a description of Grouped Query Attention operators. -#if !( \ - defined(USE_ROCM) || \ - ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) -#include -#include -#elif (defined(USE_ROCM)) -#include -#include -#endif - -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) -#include -#endif - #ifndef USE_ROCM #include #endif @@ -38,43 +23,7 @@ #define USE_WMMA_FRAG #endif -#ifdef USE_ROCM -constexpr int32_t kThreadsPerWarp = 64; -constexpr int32_t kWarpsPerBlock = 16; -#else -constexpr int32_t kThreadsPerWarp = 32; -constexpr int32_t kWarpsPerBlock = 32; -#endif - -#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) -#define FINAL_MASK 0xffffffff - -namespace fbgemm_gpu::gen_ai::attention { - -constexpr int32_t D_H = 128; -constexpr int32_t MAX_T = 16384; -constexpr int SMEM_ADJUST_THRESHOLD = 48 * 1024; - -constexpr int kMaxHeads = 8; -// Fragments shapes used for wmma tensor core operations -constexpr int F_M = 8, F_N = 32, F_K = 16; -constexpr int SMEM_K_PAD = 2; -constexpr int SMEM_V_PAD = 2; -constexpr int SMEM_K_STRIDE = F_K + SMEM_K_PAD; -constexpr int SMEM_V_STRIDE = F_N + SMEM_V_PAD; - -// Use fewer warps for gqa_attn_splitk_wmma_kernel -constexpr int32_t kSplitKWarpsPerBlock = 4; - -namespace { - -static __host__ DEVICE_INLINE int32_t div_up(int32_t a, int32_t b) { - return (a + b - 1) / b; -}; - -static __host__ DEVICE_INLINE int32_t round_up(int32_t a, int32_t b) { - return ((a + b - 1) / b) * b; -} +#include template void set_gpu_max_dynamic_shared_memory( @@ -83,14 +32,14 @@ void set_gpu_max_dynamic_shared_memory( const int device) { // V100: 96 KB; A100: 160 KB; H100: 228 KB. int max_shared_bytes = 0; - cudaDeviceGetAttribute( + C10_CUDA_CHECK(cudaDeviceGetAttribute( &max_shared_bytes, #ifndef __HIP_PLATFORM_AMD__ cudaDevAttrMaxSharedMemoryPerBlockOptin, #else hipDeviceAttributeMaxSharedMemoryPerBlock, #endif - device); + device)); C10_CUDA_KERNEL_LAUNCH_CHECK(); TORCH_CHECK( smem_bytes <= max_shared_bytes, @@ -105,380 +54,25 @@ void set_gpu_max_dynamic_shared_memory( C10_CUDA_KERNEL_LAUNCH_CHECK(); } -#ifdef __HIP_PLATFORM_AMD__ -using __nv_bfloat16 = hip_bfloat16; - -typedef struct __align__(4) { - uint16_t x; - uint16_t y; -} -__nv_bfloat162_raw; - -struct __align__(4) __nv_bfloat162 { - __nv_bfloat16 x; - __nv_bfloat16 y; -}; - -// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical -// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC -static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) { - __nv_bfloat16 output; - return output.round_to_bfloat16(f); -} - -static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) { - __nv_bfloat16 output; - return output.round_to_bfloat16(f); -} - -static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) { - // float output; - // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html - return float(f); -} - -static __host__ __device__ __nv_bfloat162 -__floats2bfloat162_rn(float x, float y) { - __nv_bfloat162 output; - output.x = __float2bfloat16_rn(x); - output.y = __float2bfloat16_rn(y); - return output; -} - -#endif - -struct __align__(16) fx4 { - float x; - float y; - float z; - float w; - __host__ __device__ fx4() { - x = 0; - y = 0; - z = 0; - w = 0; - } -}; - -struct __align__(8) bfx4 { - __nv_bfloat162 vals[2]; -}; - -// TODO: Include the following code from fbgemm_gpu header -struct __align__(16) bfx8 { - __nv_bfloat162 vals[4]; -}; - -struct __align__(8) halfx4 { - __half2 vals[2]; -}; - -struct __align__(16) halfx8 { - __half2 vals[4]; -}; - -// Reinterpret a pair of uint16_t (packed into a uint32_t) as half2, and -// multiply by rhs. -DEVICE_INLINE __half2 hmul_short2(uint32_t lhs, __half rhs) { -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 -#ifndef __HALF2_TO_UI -// cuda_fp16.hpp -#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) -#endif -#ifndef __HALF2_TO_CUI -// cuda_fp16.hpp -#define __HALF2_TO_CUI(var) *(reinterpret_cast(&(var))) -#endif - __half2 ret; - __half2 rhsp = make_half2(rhs, rhs); - asm("mul.f16x2 %0, %1, %2;" - : "=r"(__HALF2_TO_UI(ret)) - : "r"(__HALF2_TO_CUI(lhs)), "r"(__HALF2_TO_CUI(rhsp))); - return ret; -#else -#ifndef __HALF2_TO_UI -// cuda_fp16.hpp -#define __HALF2_TO_UI(var) *(reinterpret_cast(&(var))) -#endif - __half2 lhs_h2; - __HALF2_TO_UI(lhs_h2) = lhs; - float2 fx = __half22float2(lhs_h2); - float2 fy = __half22float2(make_half2(rhs, rhs)); - float2 fr; - fr.x = fx.x * fy.x; - fr.y = fx.y * fy.y; - return __float22half2_rn(fr); -#endif -} - -__forceinline__ __device__ bfx8 -dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { - halfx8 res; - uint32_t v = packedVals; - // What's going on here, you might ask? We extra out 4-bit pairs of integers - // as 2xuint16 packed into an int32 via the mask operation, and then we - // convert them to half precision values. As these are all integers in [0, - // 15], we can actually just interpret the 4-bit integer values as - // half-precision values. We multiply by 4096 x 4096 to go from the 4-bit - // representation to the equivalent fp16 value, or alternatively 32768 * 512 - // (or 32 when we have shifted the 4-bit value up). See e.g. - // https://gist.github.com/ajtulloch/021254a291a95966bc509db4e34ffeff for a - // NumPy implementation. We do this dance because: a) doing bitwise operations - // on each 4-bit value is expensive on the ALU, and 4-bit to half is expensive - // on the XU. b) doing a 256-entry shared memory LUT on 8-bit pairs is - // expensive on SMEM throughput. Credit to @jhj. - res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); - res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); - v >>= 8; - res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); - res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); - - // ~5% perf gain is observed with the explicit type conversions using - // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using - // NVCC 11.0. Additionally, HIP compiler requires these explicit type - // conversions. - half shift_scale_x = __low2half(shift_scale); - half shift_scale_y = __high2half(shift_scale); - - // now, dequantize - auto shifts = __half2(shift_scale_y, shift_scale_y); - auto scales_lower_temp = __hmul(shift_scale_x, __float2half(512)); - auto scales_lower = __half2(scales_lower_temp, scales_lower_temp); - auto scales_upper_temp = __hmul(shift_scale_x, __float2half(32)); - auto scales_upper = __half2(scales_upper_temp, scales_upper_temp); - - auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); - auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); - auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); - auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); - - bfx8 result; - result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); - result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); - result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); - result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); - - return result; -} - -// struct __align__(16) bfx8 { -// __nv_bfloat162 vals[4]; -// }; - -// DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 -// shift_scale_0); DEVICE_INLINE bfx8 dequantize_packed_int4( -// uint32_t v, -// __half2 shift_scale_0, -// __half2 shift_scale_1); -// DEVICE_INLINE bfx8 -// dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale); - -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) -DEVICE_INLINE bfx4 dequantize_packed_fp8(uint32_t vs, __half2 shift_scale_0) { - uint32_t v = vs; - __nv_fp8_e4m3* fp8_k = reinterpret_cast<__nv_fp8_e4m3*>(&v); // 4 element - - auto shift_0 = float(__high2half(shift_scale_0)); - auto scale_0 = float(__low2half(shift_scale_0)); - - // now, dequantize - auto r0 = make_float2( - float(fp8_k[0]) * scale_0 + shift_0, float(fp8_k[1]) * scale_0 + shift_0); - auto r1 = make_float2( - float(fp8_k[2]) * scale_0 + shift_0, float(fp8_k[3]) * scale_0 + shift_0); - - bfx4 result; - result.vals[0] = __floats2bfloat162_rn(r0.x, r0.y); - result.vals[1] = __floats2bfloat162_rn(r1.x, r1.y); - return result; -} -#endif - -DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 shift_scale_0) { - uint32_t v = vs; - // move 2nd byte to 3rd byte, so our bits are in 0x00FF00FF positions. - v = (v & 0xFF) | ((v & 0xFF00) << 8); - - halfx4 res; - res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); - res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); - - // ~5% perf gain is observed with the explicit type conversions using - // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using - // NVCC 11.0. Additionally, HIP compiler requires these explicit type - // conversions. - half shift_scale_0_x = __low2half(shift_scale_0); - half shift_scale_0_y = __high2half(shift_scale_0); - - // now, dequantize - auto shifts = __half2(shift_scale_0_y, shift_scale_0_y); - auto scales_lower = __half2( - __hmul(shift_scale_0_x, __float2half(512)), - __hmul(shift_scale_0_x, __float2half(512))); - auto scales_upper = __half2( - __hmul(shift_scale_0_x, __float2half(32)), - __hmul(shift_scale_0_x, __float2half(32))); - - auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); - auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); - - bfx4 result; - result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); - result.vals[1] = __floats2bfloat162_rn(r0.y, r1.y); - return result; -} - -DEVICE_INLINE bfx8 dequantize_packed_int4( - uint32_t v, - __half2 shift_scale_0, - __half2 shift_scale_1) { - halfx8 res; - res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); - res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); - v >>= 8; - res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); - res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); - - half shift_scale_0_x = __low2half(shift_scale_0); - half shift_scale_0_y = __high2half(shift_scale_0); - half shift_scale_1_x = __low2half(shift_scale_1); - half shift_scale_1_y = __high2half(shift_scale_1); - - // now, dequantize - auto shifts = __half2(shift_scale_0_y, shift_scale_1_y); - auto scales_lower = __half2( - __hmul(shift_scale_0_x, __float2half(512)), - __hmul(shift_scale_1_x, __float2half(512))); - auto scales_upper = __half2( - __hmul(shift_scale_0_x, __float2half(32)), - __hmul(shift_scale_1_x, __float2half(32))); - - auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); - auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); - auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); - auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); - - bfx8 result; - result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); - result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); - result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); - result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); - return result; -} - -DEVICE_INLINE float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#elif defined(USE_ROCM) - float2 f_val; - f_val.x = __bfloat162float(val.x); - f_val.y = __bfloat162float(val.y); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -#define CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(NAME, NUM_GROUPS, ...) \ - switch (NUM_GROUPS) { \ - case 1: \ - NAME(1, __VA_ARGS__); \ - break; \ - case 2: \ - NAME(2, __VA_ARGS__); \ - break; \ - case 4: \ - NAME(4, __VA_ARGS__); \ - break; \ - case 8: \ - NAME(8, __VA_ARGS__); \ - break; \ - case 16: \ - TORCH_CHECK( \ - false, \ - "With head dim = 128 we're almost even with int8 at this point. Are you sure about this? Num groups:", \ - NUM_GROUPS); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported number of groups: ", NUM_GROUPS); \ - } - -DEVICE_INLINE float bfx4_dot(bfx4 a, bfx4 b) { - // float2 acc = {0, 0}; - // __nv_bfloat162 acc; - // acc.x = static_cast(0); - // acc.y = static_cast(0); - // TODO: need to be performed in float32? - auto a0 = bf1622float2(a.vals[0]); - auto a1 = bf1622float2(a.vals[1]); - auto b0 = bf1622float2(b.vals[0]); - auto b1 = bf1622float2(b.vals[1]); - return a0.x * b0.x + a0.y * b0.y + a1.x * b1.x + a1.y * b1.y; - - // acc = __hfma2(a.vals[0], b.vals[0], acc); - // acc = __hfma2(a.vals[1], b.vals[1], acc); - // auto r = bf1622float2(acc); - // return r.x + r.y; -} - -DEVICE_INLINE fx4 bfx4_scale_acc(fx4 acc, bfx4 a, float b) { - auto axy = bf1622float2(a.vals[0]); - auto azw = bf1622float2(a.vals[1]); - acc.x += axy.x * b; - acc.y += axy.y * b; - acc.z += azw.x * b; - acc.w += azw.y * b; - return acc; -} - -DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; - a.w += b.w; - return a; -} +namespace fbgemm_gpu::gen_ai::attention { -DEVICE_INLINE bfx4 fx4_to_bfx4(fx4 a) { - bfx4 r; - r.vals[0] = __floats2bfloat162_rn(a.x, a.y); - r.vals[1] = __floats2bfloat162_rn(a.z, a.w); - return r; -} +constexpr int32_t D_H = 128; +constexpr int32_t MAX_T = 16384; +constexpr int SMEM_ADJUST_THRESHOLD = 48 * 1024; -template -DEVICE_INLINE T shfl_xor( - unsigned shfl_sync_mask, - const T val, - int laneMask, - int width = kThreadsPerWarp) { -#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION < 9000 - return __shfl_xor(val, laneMask, width); -#else - return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width); -#endif -} +constexpr int kMaxHeads = 8; +// Fragments shapes used for wmma tensor core operations +constexpr int F_M = 8, F_N = 32, F_K = 16; +constexpr int SMEM_K_PAD = 2; +constexpr int SMEM_V_PAD = 2; +constexpr int SMEM_K_STRIDE = F_K + SMEM_K_PAD; +constexpr int SMEM_V_STRIDE = F_N + SMEM_V_PAD; -template -DEVICE_INLINE T warpReduceSum(T val, uint32_t warp_mask = FINAL_MASK) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += shfl_xor(warp_mask, val, mask, 32); - return val; -} +// Use fewer warps for gqa_attn_splitk_wmma_kernel +constexpr int32_t kSplitKWarpsPerBlock = 4; -template -DEVICE_INLINE T warpReduceMax(T val, uint32_t warp_mask = FINAL_MASK) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, shfl_xor(warp_mask, val, mask, 32)); - return val; -} +namespace { -enum class CacheLogicalDtype { BF16, FP8, INT4 }; template < typename kv_t, int KVQuantNumGroups = 1, @@ -829,7 +423,7 @@ __global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1) __syncthreads(); } - const int hPerWarp = div_up(h_total_per_block, kSplitKWarpsPerBlock); + const int hPerWarp = div_round_up(h_total_per_block, kSplitKWarpsPerBlock); const int h_begin = warp_idx * hPerWarp; const int h_end = min(h_begin + hPerWarp, h_total_per_block); @@ -903,7 +497,7 @@ __global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1) __nv_bfloat16* smem_bf16 = reinterpret_cast<__nv_bfloat16*>(smem); float2 p[CONV_UNROLLS]; const int t_stride = blockDim.x * blockDim.y * 2; - const int t_rounds = div_up(t_total_per_block, t_stride); + const int t_rounds = div_round_up(t_total_per_block, t_stride); const int global_tid = warp_idx * blockDim.x + threadIdx.x; // Ensure that all threads finish writing to smem before modifying it in the @@ -917,7 +511,8 @@ __global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1) auto* smem_fp32_ = smem + t_start; auto* smem_bf16_ = smem_bf16 + t_start; - for (int h_i = 0; h_i < div_up(h_total_per_block, CONV_UNROLLS); h_i++) { + for (int h_i = 0; h_i < div_round_up(h_total_per_block, CONV_UNROLLS); + h_i++) { // Read FP32 #pragma unroll for (int h_j = 0; h_j < CONV_UNROLLS; h_j++) { @@ -1209,7 +804,7 @@ __global__ void gqa_attn_splitk_reduce_wmma_kernel( const int32_t t_max = seq_positions[b] + 1; const int32_t t_total = round_up(t_max, num_split_ks); const int32_t t_per_block = t_total / num_split_ks; - const int32_t num_split_ks_max = div_up(t_max, t_per_block); + const int32_t num_split_ks_max = div_round_up(t_max, t_per_block); for (int k = 1; k < num_split_ks_max; ++k) { float m_k = metadata[b][0][k][h]; @@ -1829,7 +1424,7 @@ std::tuple gqa_attn_splitk_wmma_impl( auto metadata = at::empty({B, 2, num_split_ks, H}, out_splitK.options()); // TODO: Check if the grid size is valid - const int32_t H_blocks = div_up(H, kMaxHeads); + const int32_t H_blocks = div_round_up(H, kMaxHeads); dim3 blocks(B, H_blocks, num_split_ks); dim3 threads(kThreadsPerWarp, kSplitKWarpsPerBlock); @@ -1837,7 +1432,7 @@ std::tuple gqa_attn_splitk_wmma_impl( return {O, out_splitK, metadata}; } - const int32_t t_per_block = div_up(cache_K.size(1), num_split_ks); + const int32_t t_per_block = div_round_up(cache_K.size(1), num_split_ks); // This is called ldc inside gqa_attn_splitk_wmma_kernel kernel const int32_t t_per_block_round_up = round_up(t_per_block, F_N); diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu index d09a339af1..03dd261fe3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu @@ -30,78 +30,15 @@ #include #endif #include -// #include "cuda_dispatch_utils.h" #include #include -#if ( \ - defined(__CUDA_ARCH__) && \ - ((__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900))) -#define USE_WMMA_FRAG -#endif +#include namespace fbgemm_gpu { -#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) - -static __host__ DEVICE_INLINE int32_t div_up(int32_t a, int32_t b) { - return (a + b - 1) / b; -}; - -#ifdef __HIP_PLATFORM_AMD__ -constexpr int32_t kThreadsPerWarp = 64; -#else -constexpr int32_t kThreadsPerWarp = 32; -#endif - -#ifdef __HIP_PLATFORM_AMD__ -using __nv_bfloat16 = hip_bfloat16; - -typedef struct __align__(4) { - uint16_t x; - uint16_t y; -} -__nv_bfloat162_raw; - -struct __align__(4) __nv_bfloat162 { - __nv_bfloat16 x; - __nv_bfloat16 y; -}; - -// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical -// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC -static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) { - __nv_bfloat16 output; - return output.round_to_bfloat16(f); -} - -static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) { - __nv_bfloat16 output; - return output.round_to_bfloat16(f); -} - -static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) { - // float output; - // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html - return float(f); -} - -static __host__ __device__ __nv_bfloat162 -__floats2bfloat162_rn(float x, float y) { - __nv_bfloat162 output; - output.x = __float2bfloat16_rn(x); - output.y = __float2bfloat16_rn(y); - return output; -} - -#endif - -struct __align__(16) bf16x8 { - __nv_bfloat162 vals[4]; -}; - DEVICE_INLINE __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -545,14 +482,14 @@ void one_shot_car_allreduce( dim3 threads(0, 1, 1); dim3 blocks(0, 1, 1); if (N < N_per_thread * kThreadsPerBlock) { - threads.x = div_up(N, N_per_warp) * kThreadsPerWarp; + threads.x = div_round_up(N, N_per_warp) * kThreadsPerWarp; blocks.x = 1; } else { - auto warps_required = div_up(N, N_per_warp); + auto warps_required = div_round_up(N, N_per_warp); blocks.x = std::min( - cuda_calc_block_count(div_up(N, N_per_thread), kThreadsPerBlock), + cuda_calc_block_count(div_round_up(N, N_per_thread), kThreadsPerBlock), kMaxBlocks); - auto warps_per_block = div_up(warps_required, blocks.x); + auto warps_per_block = div_round_up(warps_required, blocks.x); auto threads_per_block = std::min(kThreadsPerBlock, warps_per_block * kThreadsPerWarp); diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp new file mode 100644 index 0000000000..6caeca3fc3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include +#include +#include +#include +#include + +#include "c10/util/Exception.h" + +namespace fbgemm_gpu { + +#define DEFAULT_PAGE_SIZE 64 +#define STRING_(s) #s +#define STRING(x) STRING_(x) + +at::Tensor rope_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor rope_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor xpos_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor xpos_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +std::tuple dequantize_int4_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen, + std::optional num_groups); + +std::tuple dequantize_fp8_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen); + +at::Tensor mqa_attn( + at::Tensor XQ, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seq_positions, + double qk_scale, + std::optional num_groups, + int64_t cache_logical_dtype_int); + +#define DEFAULT_PAGE_SIZE 64 +#define STRING_(s) #s +#define STRING(x) STRING_(x) + +at::Tensor rope_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor rope_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor xpos_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +at::Tensor xpos_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling, + int64_t old_context_len, + double scaling_factor, + double lo_freq_factor, + double hi_freq_factor); + +std::tuple dequantize_int4_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen, + std::optional num_groups); + +std::tuple dequantize_fp8_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen); + +at::Tensor mqa_attn( + at::Tensor XQ, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seq_positions, + double qk_scale, + std::optional num_groups, + int64_t cache_logical_dtype_int); + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor"); + m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill); + m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor"); + m.impl("rope_qkv_decoding", rope_qkv_decoding); + m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor"); + m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill); + m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor"); + m.impl("xpos_qkv_decoding", xpos_qkv_decoding); + + m.def( + "dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)"); + m.impl("dequantize_int4_cache", dequantize_int4_cache); + m.def( + "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen) -> (Tensor, Tensor)"); + m.impl("dequantize_fp8_cache", dequantize_fp8_cache); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu new file mode 100644 index 0000000000..871e2bd11d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -0,0 +1,1527 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "c10/core/ScalarType.h" +#include "c10/util/BFloat16.h" + +#ifndef USE_ROCM +#include +#endif +#include + +#include +#include + +#include + +template +void set_gpu_max_dynamic_shared_memory( + func_t kernel, + const int smem_bytes, + const int device) { + // V100: 96 KB; A100: 160 KB; H100: 228 KB. + int max_shared_bytes = 0; + C10_CUDA_CHECK(cudaDeviceGetAttribute( + &max_shared_bytes, +#ifndef __HIP_PLATFORM_AMD__ + cudaDevAttrMaxSharedMemoryPerBlockOptin, +#else + hipDeviceAttributeMaxSharedMemoryPerBlock, +#endif + device)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + TORCH_CHECK( + smem_bytes <= max_shared_bytes, + "Try to allocate ", + smem_bytes / 1024, + " KB of shared memory but only ", + max_shared_bytes / 1024, + " KB is available"); + + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +namespace fbgemm_gpu { + +template +__global__ void dequantize_int4_cache_kernel( + at::PackedTensorAccessor64 + cache_K, // [B][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V, // [B][MAX_T][N_KVH][D_H // G] + at::PackedTensorAccessor32 kv_seqlen, + at::PackedTensorAccessor64 + cache_K_dq, // [B][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V_dq // [B][MAX_T][N_KVH][D_H] +) { + auto N_KVH = cache_K.size(2); + auto MAX_T = cache_K.size(1); + auto D_H = cache_K_dq.size(3); + + auto b = blockIdx.x; + // only need to dequantize this far. + auto max_t = kv_seqlen[b]; + + // one warp per T/H + for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; + t_h += blockDim.y * gridDim.y) { + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + auto* row_k = &cache_K[b][t][h][0]; + auto* row_v = &cache_V[b][t][h][0]; + bfx8 kv_dq; + if (KVQuantNumGroups == 1) { + __half2 k_shift_scale; + __half2 v_shift_scale; + *reinterpret_cast(&k_shift_scale) = + *reinterpret_cast(&row_k[0]); + *reinterpret_cast(&v_shift_scale) = + *reinterpret_cast(&row_v[0]); + if (4 * threadIdx.x >= D_H) { + continue; + } + uint32_t kq = *reinterpret_cast(&row_k[threadIdx.x * 2 + 4]); + uint32_t vq = *reinterpret_cast(&row_v[threadIdx.x * 2 + 4]); + + uint32_t packed = kq | (vq << 16); + kv_dq = dequantize_packed_int4(packed, k_shift_scale, v_shift_scale); + + } else { + __half2 k_shift_scale; + __half2 v_shift_scale; + int32_t group_size = D_H / KVQuantNumGroups; + int32_t group_idx = threadIdx.x * 4 / group_size; + + *reinterpret_cast(&k_shift_scale) = + *reinterpret_cast(&row_k[4 * group_idx]); + *reinterpret_cast(&v_shift_scale) = + *reinterpret_cast(&row_v[4 * group_idx]); + + int32_t int4_qparam_offset = 4 * KVQuantNumGroups; + + if (4 * threadIdx.x >= D_H) { + continue; + } + + uint32_t kq = *reinterpret_cast( + &row_k[threadIdx.x * 2 + int4_qparam_offset]); + uint32_t vq = *reinterpret_cast( + &row_v[threadIdx.x * 2 + int4_qparam_offset]); + + uint32_t packed = kq | (vq << 16); + + kv_dq = dequantize_packed_int4(packed, k_shift_scale, v_shift_scale); + } + // now, write our outputs + auto* row_k_dq = &cache_K_dq[b][t][h][0]; + auto* row_v_dq = &cache_V_dq[b][t][h][0]; + + *reinterpret_cast(&row_k_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[0]); + *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[2]); + } +} + +#define CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL(NUM_GROUPS, ...) \ + dequantize_int4_cache_kernel< \ + NUM_GROUPS><<>>( \ + cache_K.packed_accessor64(), \ + cache_V.packed_accessor64(), \ + kv_seqlen.packed_accessor32(), \ + cache_K_dq.packed_accessor64(), \ + cache_V_dq.packed_accessor64()); + +std::tuple dequantize_int4_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen, + std::optional num_groups) { + // allocate DQ outputs + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + TORCH_CHECK(kv_seqlen.is_cuda()); + auto B = cache_K.size(0); + auto MAX_T = cache_K.size(1); + auto N_KVH = cache_K.size(2); + auto D_HQ = cache_K.size(3); + // D_HQ == D_H // 2 + 8 (int4 + 4xhalf qparams) + auto num_groups_ = num_groups ? num_groups.value() : 1; + auto int4_qparam_offset = 4 * num_groups_; + auto D_H = (D_HQ - int4_qparam_offset) * 2; + + auto cache_K_dq = + at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + auto cache_V_dq = + at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + + if (B == 0) { + return {cache_K_dq, cache_V_dq}; + } + + constexpr int32_t kMaxBlocks = 256; + dim3 blocks(B, std::max(1, kMaxBlocks / B)); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_DEQUANTIZE_INT4_CACHE_GROUPWISE_KERNEL, num_groups_) + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return {cache_K_dq, cache_V_dq}; +} + +template +__device__ void get_dst_row( + T** dst_row, + at::PackedTensorAccessor64& + cache_KV, // [B][MAX_T][N_KVH][D_H +4 or D_H] + int32_t b, + int32_t h, + int32_t cache_loc_t, + int32_t page_size, + int32_t* block_tables, + int32_t block_tables_b_stride) { + if (block_tables == nullptr) { + *dst_row = &cache_KV[b][cache_loc_t][h][0]; + } else { + int page_logical_idx = cache_loc_t / page_size; + int page_offset = cache_loc_t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + *dst_row = &cache_KV[0][page_physical_idx * page_size + page_offset][h][0]; + } +} + +enum class PositionEmbeddingMode { ROPE = 0, XPOS = 1 }; +enum class QKV { Q, K, V }; +DEVICE_INLINE void quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q); + +template +__global__ void rope_xpos_qkv_varseq_prefill_kernel( + at::PackedTensorAccessor32 + XQ, // [B_T][N_H][D_H] + at::PackedTensorAccessor32 + XK, // [B_T][N_KVH][D_H] + at::PackedTensorAccessor32 + XV, // [B_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_K, // [B][MAX_T][N_KVH][D_H] or + // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention + at::PackedTensorAccessor64 + cache_V, // [B][MAX_T][N_KVH][D_H] or + // [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention + at::PackedTensorAccessor32 + XQ_O, // [B_T][N_H][D] + int32_t* varseq_batch, // in decoding case we have T == 1 and so just pass + // nullptr + at::PackedTensorAccessor32 varseq_seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + int32_t* block_tables, // [B][MAX_PAGES], maps logical pages to physical + // ones for paged attention + int32_t page_size, + int32_t block_tables_b_stride, + at::PackedTensorAccessor32 + varseq_cache_seqpos, + int64_t* actual_batch_size = + nullptr, // When running in CUDA graph mode, the actual batch size + // can be smaller than block_tables.size(0). In this case + // rows of block_tables beyond actual_batch_size are not + // initialized, and using them wil cause undefined + // behavior. To prevent this, when actual_batch_size is + // provided, the kernel exits if the current batch index is + // larger of equal to actual_batch_size, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + // Launch b_t_(sum(h)) warps. + auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y; + auto B_T = XQ.size(0); + auto N_KVH = XK.size(1); + auto N_H = XQ.size(1); + auto D_H = XQ.size(2); + auto HH = 2 * N_KVH + N_H; + + auto hh = b_t_hh % HH; + auto b_t = b_t_hh / HH; + if (b_t >= B_T) { + return; + } + auto seqpos_t = varseq_seqpos[b_t]; + if (seqpos_t == -1) { + return; + } + auto cache_loc_t = varseq_cache_seqpos[b_t]; + auto b = varseq_batch ? varseq_batch[b_t] : b_t; + + if (actual_batch_size != nullptr && b_t >= *actual_batch_size) { + return; + } + + at::BFloat16* src_row; + at::BFloat16* dst_row; + auto h = 0; + QKV qkv; + if (hh < N_H) { + h = hh; + src_row = &XQ[b_t][h][0]; + dst_row = &XQ_O[b_t][h][0]; + qkv = QKV::Q; + } else if (hh < N_H + N_KVH) { + h = hh - N_H; + src_row = &XK[b_t][h][0]; + + get_dst_row( + &dst_row, + cache_K, + b, + h, + cache_loc_t, + page_size, + block_tables, + block_tables_b_stride); + qkv = QKV::K; + } else { + h = hh - N_H - N_KVH; + src_row = &XV[b_t][h][0]; + get_dst_row( + &dst_row, + cache_V, + b, + h, + cache_loc_t, + page_size, + block_tables, + block_tables_b_stride); + qkv = QKV::V; + } + + for (int32_t head_id = 4 * threadIdx.x; head_id < D_H; + head_id += kThreadsPerWarp * 4) { + // assert D_H % 4 == 0; + // load 4 elements per thread in a warp. + if (head_id >= D_H) { + return; + } + + bfx4 src; + *reinterpret_cast(&src) = + *reinterpret_cast(&src_row[head_id]); + if (qkv == QKV::V) { + *reinterpret_cast(&dst_row[head_id]) = + *reinterpret_cast(&src); + } else { + int32_t offset_0 = ((head_id) / 2 + 0); + int32_t offset_1 = ((head_id) / 2 + 1); + + double powers_0 = offset_0 * 2; + double powers_1 = offset_1 * 2; + + double freqs_0 = pow(theta, powers_0 / -static_cast(D_H)); + double freqs_1 = pow(theta, powers_1 / -static_cast(D_H)); + if (rope_scaling) { + double lo_freq_wavelen = old_context_len / lo_freq_factor; + double hi_freq_wavelen = old_context_len / hi_freq_factor; + double wavelen_0 = 2 * M_PI / freqs_0; + if (wavelen_0 >= hi_freq_wavelen && wavelen_0 > lo_freq_wavelen) { + freqs_0 = freqs_0 / scaling_factor; + } else if (wavelen_0 >= hi_freq_wavelen) { + double smooth = (old_context_len / wavelen_0 - lo_freq_factor) / + (hi_freq_factor - lo_freq_factor); + freqs_0 = (1 - smooth) * freqs_0 / scaling_factor + smooth * freqs_0; + } + double wavelen_1 = 2 * M_PI / freqs_1; + if (wavelen_1 >= hi_freq_wavelen && wavelen_1 > lo_freq_wavelen) { + freqs_1 = freqs_1 / scaling_factor; + } else if (wavelen_1 >= hi_freq_wavelen) { + double smooth = (old_context_len / wavelen_1 - lo_freq_factor) / + (hi_freq_factor - lo_freq_factor); + freqs_1 = (1 - smooth) * freqs_1 / scaling_factor + smooth * freqs_1; + } + } + freqs_0 = static_cast(seqpos_t) * freqs_0; + freqs_1 = static_cast(seqpos_t) * freqs_1; + + double sin_0, sin_1, cos_0, cos_1; + sincos(freqs_0, &sin_0, &cos_0); + sincos(freqs_1, &sin_1, &cos_1); + + auto src_0 = bf1622float2(src.vals[0]); + auto src_1 = bf1622float2(src.vals[1]); + + double dst_x, dst_y, dst_z, dst_w; + + dst_x = static_cast(src_0.x) * cos_0 - + static_cast(src_0.y) * sin_0; + dst_y = static_cast(src_0.y) * cos_0 + + static_cast(src_0.x) * sin_0; + + dst_z = static_cast(src_1.x) * cos_1 - + static_cast(src_1.y) * sin_1; + dst_w = static_cast(src_1.y) * cos_1 + + static_cast(src_1.x) * sin_1; + + if (Mode == PositionEmbeddingMode::XPOS) { + double gamma_0 = (powers_0 + gamma * D_H) / (D_H + gamma * D_H); + double gamma_1 = (powers_1 + gamma * D_H) / (D_H + gamma * D_H); + double scale_base_ = (qkv == QKV::Q) ? scale_base : -scale_base; + double factor_0 = pow( + gamma_0, + (static_cast(seqpos_t) - exponent_offset) / scale_base_); + double factor_1 = pow( + gamma_1, + (static_cast(seqpos_t) - exponent_offset) / scale_base_); + + dst_x *= factor_0; + dst_y *= factor_0; + dst_z *= factor_1; + dst_w *= factor_1; + } + + fx4 dst; + dst.x = __double2float_rn(dst_x); + dst.y = __double2float_rn(dst_y); + dst.z = __double2float_rn(dst_z); + dst.w = __double2float_rn(dst_w); + + bfx4 dst_; + dst_.vals[0] = __floats2bfloat162_rn(dst.x, dst.y); + dst_.vals[1] = __floats2bfloat162_rn(dst.z, dst.w); + *reinterpret_cast(&dst_row[head_id]) = + *reinterpret_cast(&dst_); + } + } +} + +template +DEVICE_INLINE fx4 rope_xpos( + bfx4 src, + int32_t seqpos_t, + QKV head, + double theta, + double gamma, + double scale_base, + double exponent_offset, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + fx4 dst; // read 4 bf16 from src and store in 4 float registers + if (head == QKV::V) { + auto r0 = bf1622float2(src.vals[0]); + auto r1 = bf1622float2(src.vals[1]); + dst.x = r0.x; + dst.y = r0.y; + dst.z = r1.x; + dst.w = r1.y; + return dst; + } + int32_t offset_0 = ((4 * threadIdx.x) / 2 + 0); + int32_t offset_1 = ((4 * threadIdx.x) / 2 + 1); + + double powers_0 = offset_0 * 2; + double powers_1 = offset_1 * 2; + + double freqs_0 = pow(theta, powers_0 / -static_cast(D_H)); + double freqs_1 = pow(theta, powers_1 / -static_cast(D_H)); + + if (rope_scaling) { + // From https://github.com/fairinternal/llm_inference/pull/391 + // See https://arxiv.org/pdf/2309.16039 , https://fburl.com/eyhqrzhn + double lo_freq_wavelen = old_context_len / lo_freq_factor; + double hi_freq_wavelen = old_context_len / hi_freq_factor; + double wavelen_0 = 2 * M_PI / freqs_0; + if (wavelen_0 >= hi_freq_wavelen && wavelen_0 > lo_freq_wavelen) { + freqs_0 = freqs_0 / scaling_factor; + } else if (wavelen_0 >= hi_freq_wavelen) { + double smooth = (old_context_len / wavelen_0 - lo_freq_factor) / + (hi_freq_factor - lo_freq_factor); + freqs_0 = (1 - smooth) * freqs_0 / scaling_factor + smooth * freqs_0; + } + double wavelen_1 = 2 * M_PI / freqs_1; + if (wavelen_1 >= hi_freq_wavelen && wavelen_1 > lo_freq_wavelen) { + freqs_1 = freqs_1 / scaling_factor; + } else if (wavelen_1 >= hi_freq_wavelen) { + double smooth = (old_context_len / wavelen_1 - lo_freq_factor) / + (hi_freq_factor - lo_freq_factor); + freqs_1 = (1 - smooth) * freqs_1 / scaling_factor + smooth * freqs_1; + } + } + freqs_0 = static_cast(seqpos_t) * freqs_0; + freqs_1 = static_cast(seqpos_t) * freqs_1; + + double sin_0, sin_1, cos_0, cos_1; + sincos(freqs_0, &sin_0, &cos_0); + sincos(freqs_1, &sin_1, &cos_1); + + auto src_0 = bf1622float2(src.vals[0]); + auto src_1 = bf1622float2(src.vals[1]); + + double dst_x, dst_y, dst_z, dst_w; + + dst_x = static_cast(src_0.x) * cos_0 - + static_cast(src_0.y) * sin_0; + dst_y = static_cast(src_0.y) * cos_0 + + static_cast(src_0.x) * sin_0; + + dst_z = static_cast(src_1.x) * cos_1 - + static_cast(src_1.y) * sin_1; + dst_w = static_cast(src_1.y) * cos_1 + + static_cast(src_1.x) * sin_1; + + if (EmbMode == PositionEmbeddingMode::XPOS) { + double gamma_0 = (powers_0 + gamma * D_H) / (D_H + gamma * D_H); + double gamma_1 = (powers_1 + gamma * D_H) / (D_H + gamma * D_H); + double scale_base_ = (head == QKV::Q) ? scale_base : -scale_base; + double factor_0 = + pow(gamma_0, + (static_cast(seqpos_t) - exponent_offset) / scale_base_); + double factor_1 = + pow(gamma_1, + (static_cast(seqpos_t) - exponent_offset) / scale_base_); + dst_x *= factor_0; + dst_y *= factor_0; + dst_z *= factor_1; + dst_w *= factor_1; + } + + dst.x = __double2float_rn(dst_x); + dst.y = __double2float_rn(dst_y); + dst.z = __double2float_rn(dst_z); + dst.w = __double2float_rn(dst_w); + + return dst; +} + +template +DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) { + auto thread_min = fminf(fminf(fminf(dst.x, dst.y), dst.z), dst.w); + auto thread_max = fmaxf(fmaxf(fmaxf(dst.x, dst.y), dst.z), dst.w); + + float warp_min, warp_max; + + int32_t int4_qparam_offset = 4; + if (KVQuantNumGroups == 1) { + unsigned mask = ballot_sync(4 * threadIdx.x < D_H, 0xFFFFFFFF); + warp_min = -warpReduceMax(-thread_min, mask); + warp_max = warpReduceMax(thread_max, mask); + } else { + int32_t group_size = D_H / KVQuantNumGroups; + int32_t group_idx = threadIdx.x * 4 / group_size; + int4_qparam_offset = 4 * KVQuantNumGroups; + unsigned masks[KVQuantNumGroups]; + for (int i = 0; i < KVQuantNumGroups; ++i) { + masks[i] = ballot_sync(group_idx == i, 0xFFFFFFFF); + } + warp_min = -warpReduceMax(-thread_min, masks[group_idx]); + warp_max = warpReduceMax(thread_max, masks[group_idx]); + } + + auto scale = (warp_max - warp_min) / 15.0f; + auto inv_scale = 15.0 / (scale * 15.0 + 1.0e-8); + auto shift = warp_min; + + auto x_0 = __float2int_rn((dst.x - shift) * inv_scale) & 0xF; + auto x_1 = __float2int_rn((dst.y - shift) * inv_scale) & 0xF; + auto x_2 = __float2int_rn((dst.z - shift) * inv_scale) & 0xF; + auto x_3 = __float2int_rn((dst.w - shift) * inv_scale) & 0xF; + + uint16_t packed = 0; + + packed |= (x_0 << 0); + packed |= (x_1 << 4); + packed |= (x_2 << 8); + packed |= (x_3 << 12); + + // each threadIdx.x writes 2 bytes with 4+4 byte offset for scale/shift + + CUDA_KERNEL_ASSERT( + uintptr_t(&dst_row_q[2 * threadIdx.x + int4_qparam_offset]) % 2 == 0); + + *reinterpret_cast( + &dst_row_q[2 * threadIdx.x + int4_qparam_offset]) = packed; + if (threadIdx.x == 0) { + CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[0]) % 4 == 0); + __half2 qparams = __floats2half2_rn(scale, shift); + *reinterpret_cast<__half2*>(&dst_row_q[0]) = qparams; + } + if (KVQuantNumGroups > 1) { + int32_t group_size = D_H / KVQuantNumGroups; + if (threadIdx.x > 0 && threadIdx.x * 4 % group_size == 0) { + int32_t group_idx = threadIdx.x * 4 / group_size; + int32_t qparam_offset = 4 * group_idx; + CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[qparam_offset]) % 4 == 0); + __half2 qparams = __floats2half2_rn(scale, shift); + *reinterpret_cast<__half2*>(&dst_row_q[qparam_offset]) = qparams; + } + } +} + +#define CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( \ + NUM_GROUPS, \ + DTYPE, \ + EMB_MODE, \ + VARSEQ_BATCH, \ + VARSEQ_SEQPOS, \ + THETA, \ + GAMMA, \ + SCALE_BASE, \ + EXPO_OFFSET, \ + block_tables, \ + page_size, \ + block_tables_b_stride, \ + varseq_cache_seqpos, \ + actual_batch_size, \ + rope_scaling, \ + old_context_len, \ + scaling_factor, \ + lo_freq_factor, \ + hi_freq_factor) \ + rope_xpos_qkv_varseq_prefill_kernel_ \ + <<>>( \ + XQ.packed_accessor32(), \ + XK.packed_accessor32(), \ + XV.packed_accessor32(), \ + cache_K.packed_accessor64(), \ + cache_V.packed_accessor64(), \ + XQ_O.packed_accessor32(), \ + VARSEQ_BATCH, \ + VARSEQ_SEQPOS, \ + THETA, \ + GAMMA, \ + SCALE_BASE, \ + EXPO_OFFSET, \ + block_tables, \ + page_size, \ + block_tables_b_stride, \ + varseq_cache_seqpos, \ + actual_batch_size, \ + rope_scaling, \ + old_context_len, \ + scaling_factor, \ + lo_freq_factor, \ + hi_freq_factor); + +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) +class FP8_E4M3_MAX { + public: +#ifndef USE_ROCM + static constexpr float value = 448.0; +#else + static constexpr float value = 240.0; +#endif +}; +class FP8_E5M2_MAX { + public: + static constexpr float value = 57344.0; +}; +#endif + +template < + PositionEmbeddingMode EmbMode, + CacheLogicalDtype kCacheDtype, + int KVQuantNumGroups = 1> +__global__ void rope_xpos_qkv_varseq_prefill_kernel_( + at::PackedTensorAccessor32 + XQ, // [B_T][N_H][D_H] + at::PackedTensorAccessor32 + XK, // [B_T][N_KVH][D_H] + at::PackedTensorAccessor32 + XV, // [B_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_K, // [B][MAX_T][N_KVH][D_H +4] + at::PackedTensorAccessor64 + cache_V, // [B][MAX_T][N_KVH][D_H + 4] # What is G ? Should be D_H * + // (sizeof(uint8) // sizeof(fp8)) + at::PackedTensorAccessor32 + XQ_O, // [B_T][N_H][D] + int32_t* varseq_batch, // in decoding case we have T == 1 and so just + // pass nullptr + at::PackedTensorAccessor32 varseq_seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + int32_t* block_tables, // [B][MAX_PAGES], maps logical pages to physical + // ones for paged attention + int32_t page_size, + int32_t block_tables_b_stride, + at::PackedTensorAccessor32 + varseq_cache_seqpos, + int64_t* actual_batch_size = + nullptr, // When running in CUDA graph mode, the actual batch size + // can be smaller than block_tables.size(0). In this case + // rows of block_tables beyond actual_batch_size are not + // initialized, and using them wil cause undefined + // behavior. To prevent this, when actual_batch_size is + // provided, the kernel exits if the current batch index is + // larger of equal to actual_batch_size, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + // Launch b_t_(sum(h)) warps. + auto b_t_hh = blockIdx.x * blockDim.y + + threadIdx.y; // Block = [kThreadsPerWarp, kWarpsPerBlock] + // Each warp handles a single head XQ or XK or XV of a single token.. + // That would be 1 x 128 distributed among 32 threads in the warp. + // Each thread should handle 4 elements. + auto B_T = XQ.size(0); + auto N_KVH = XK.size(1); + auto N_H = XQ.size(1); + auto D_H = XQ.size(2); + + auto HH = 2 * N_KVH + N_H; + + auto hh = b_t_hh % HH; + auto b_t = b_t_hh / HH; + if (b_t >= B_T) { + return; + } + auto seqpos_t = varseq_seqpos[b_t]; + if (seqpos_t == -1) { + return; + } + auto cache_loc_t = varseq_cache_seqpos[b_t]; + auto b = varseq_batch ? varseq_batch[b_t] : b_t; + + if (actual_batch_size != nullptr && b_t >= *actual_batch_size) { + return; + } + + at::BFloat16* src_row = nullptr; + at::BFloat16* dst_row = nullptr; + uint8_t* dst_row_q = nullptr; + auto h = 0; + QKV qkv; + if (hh < N_H) { + h = hh; + src_row = &XQ[b_t][h][0]; + dst_row = &XQ_O[b_t][h][0]; + qkv = QKV::Q; + } else if (hh < N_H + N_KVH) { + h = hh - N_H; + src_row = &XK[b_t][h][0]; + get_dst_row( + &dst_row_q, + cache_K, + b, + h, + cache_loc_t, + page_size, + block_tables, + block_tables_b_stride); + qkv = QKV::K; + } else { + h = hh - N_H - N_KVH; + src_row = &XV[b_t][h][0]; + get_dst_row( + &dst_row_q, + cache_V, + b, + h, + cache_loc_t, + page_size, + block_tables, + block_tables_b_stride); + qkv = QKV::V; + } + + // load 4 elements per thread in a warp. + + // Each thread should handle D_H//32 = 4 elements. + CUDA_KERNEL_ASSERT(D_H <= 4 * kThreadsPerWarp); + if (4 * threadIdx.x >= D_H) { + return; + } + bfx4 src; + *reinterpret_cast(&src) = + *reinterpret_cast(&src_row[4 * threadIdx.x]); + + fx4 dst = rope_xpos( + src, + seqpos_t, + qkv, + theta, + gamma, + scale_base, + exponent_offset, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + // now we have our output. + if (qkv == QKV::Q) { // is_q // store to Qo without quantization + bfx4 dst_ = fx4_to_bfx4(dst); + CUDA_KERNEL_ASSERT(uintptr_t(&dst_row[4 * threadIdx.x]) % 8 == 0); + + *reinterpret_cast(&dst_row[4 * threadIdx.x]) = + *reinterpret_cast(&dst_); + } else { + if (kCacheDtype == CacheLogicalDtype::FP8) { + // fp8 quantization + quantize_fp8_kv(dst, dst_row_q); + + } else if (kCacheDtype == CacheLogicalDtype::INT4) { + quantize_int4_kv(dst, dst_row_q); + } + } +} + +at::Tensor rope_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + auto B_T = XQ.size(0); + auto N_H = XQ.size(1); + auto N_KVH = XK.size(1); + + TORCH_CHECK(XQ.size(2) % 4 == 0); + TORCH_CHECK(XQ.size(2) <= 512); + + int32_t num_warps = B_T * (2 * N_KVH + N_H); + TORCH_CHECK(num_warps > 0); + + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock)); + + TORCH_CHECK(varseq_batch.is_contiguous()); + TORCH_CHECK(varseq_batch.numel() == B_T); + auto XQ_O = at::empty_like(XQ); + + auto varseq_cache_seqpos_ = varseq_cache_seqpos.value_or(varseq_seqpos); + + CacheLogicalDtype cache_logical_dtype = + static_cast(cache_logical_dtype_int); + + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + if (cache_K.dtype() == at::kBFloat16) { + rope_xpos_qkv_varseq_prefill_kernel + <<>>( + XQ.packed_accessor32(), + XK.packed_accessor32(), + XV.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + XQ_O.packed_accessor32(), + varseq_batch.data_ptr(), + varseq_seqpos + .packed_accessor32(), + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + varseq_cache_seqpos_ + .packed_accessor32(), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto num_groups_ = num_groups ? num_groups.value() : 1; + auto varseq_batch_ = varseq_batch.data_ptr(); + auto varseq_seqpos_ = + varseq_seqpos.packed_accessor32(); + if (cache_logical_dtype == CacheLogicalDtype::FP8) { +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) + CUDA_KERNEL_ASSERT(num_groups_ == 1); + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( + 1, + CacheLogicalDtype::FP8, + PositionEmbeddingMode::ROPE, + varseq_batch_, + varseq_seqpos_, + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + (varseq_cache_seqpos_ + .packed_accessor32()), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else + throw std::runtime_error("CUDA version is older than 12.0"); +#endif + } else { + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL, + num_groups_, + CacheLogicalDtype::INT4, + PositionEmbeddingMode::ROPE, + varseq_batch_, + varseq_seqpos_, + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + (varseq_cache_seqpos_ + .packed_accessor32()), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + return XQ_O; +} + +at::Tensor xpos_qkv_varseq_prefill( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor varseq_batch, + at::Tensor varseq_seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional varseq_cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + auto B_T = XQ.size(0); + auto N_H = XQ.size(1); + auto N_KVH = XK.size(1); + + TORCH_CHECK(XQ.size(2) % 4 == 0); + TORCH_CHECK(XQ.size(2) <= 512); + + int32_t num_warps = B_T * (2 * N_KVH + N_H); + TORCH_CHECK(num_warps > 0); + + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock)); + + auto XQ_O = at::empty_like(XQ); + TORCH_CHECK(varseq_batch.is_contiguous()); + TORCH_CHECK(varseq_batch.numel() == B_T); + auto varseq_cache_seqpos_ = varseq_cache_seqpos.value_or(varseq_seqpos); + CacheLogicalDtype cache_logical_dtype = + static_cast(cache_logical_dtype_int); + + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + + if (cache_K.dtype() == at::kBFloat16) { + rope_xpos_qkv_varseq_prefill_kernel + <<>>( + XQ.packed_accessor32(), + XK.packed_accessor32(), + XV.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + XQ_O.packed_accessor32(), + varseq_batch.data_ptr(), + varseq_seqpos + .packed_accessor32(), + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + varseq_cache_seqpos_ + .packed_accessor32(), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto num_groups_ = num_groups ? num_groups.value() : 1; + auto varseq_batch_ = varseq_batch.data_ptr(); + auto varseq_seqpos_ = + varseq_seqpos.packed_accessor32(); + if (cache_logical_dtype == CacheLogicalDtype::FP8) { +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) + CUDA_KERNEL_ASSERT(num_groups_ == 1); + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( + 1, + CacheLogicalDtype::FP8, + PositionEmbeddingMode::XPOS, + varseq_batch_, + varseq_seqpos_, + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + (varseq_cache_seqpos_ + .packed_accessor32()), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else + throw std::runtime_error("CUDA version is older than 12.0"); +#endif + } else { + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL, + num_groups_, + CacheLogicalDtype::INT4, + PositionEmbeddingMode::XPOS, + varseq_batch_, + varseq_seqpos_, + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + (varseq_cache_seqpos_ + .packed_accessor32()), + nullptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + return XQ_O; +} + +at::Tensor rope_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + auto B = XQ.size(0); + auto N_H = XQ.size(1); + auto N_KVH = XK.size(1); + + TORCH_CHECK(XQ.size(2) % 4 == 0); + int32_t num_warps = B * (2 * N_KVH + N_H); + TORCH_CHECK(num_warps > 0); + + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock)); + auto XQ_O = at::empty_like(XQ); + + CacheLogicalDtype cache_logical_dtype = + static_cast(cache_logical_dtype_int); + + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + int64_t* actual_batch_size_ptr = nullptr; + if (actual_batch_size.has_value()) { + actual_batch_size_ptr = + static_cast(actual_batch_size.value().data_ptr()); + } + auto cache_seqpos_ = cache_seqpos.value_or(seqpos); + if (cache_K.dtype() == at::kBFloat16) { + rope_xpos_qkv_varseq_prefill_kernel + <<>>( + XQ.packed_accessor32(), + XK.packed_accessor32(), + XV.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + XQ_O.packed_accessor32(), + batch.has_value() ? batch.value().data_ptr() : nullptr, + seqpos.packed_accessor32(), + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + cache_seqpos_ + .packed_accessor32(), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto seqpos_ = + seqpos.packed_accessor32(); + auto num_groups_ = num_groups ? num_groups.value() : 1; + if (cache_logical_dtype == CacheLogicalDtype::FP8) { +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) + CUDA_KERNEL_ASSERT(num_groups_ == 1); + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( + 1, + CacheLogicalDtype::FP8, + PositionEmbeddingMode::ROPE, + nullptr, + seqpos_, + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + (cache_seqpos_ + .packed_accessor32()), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else + throw std::runtime_error("CUDA version is older than 12.0"); +#endif + } else { + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL, + num_groups_, + CacheLogicalDtype::INT4, + PositionEmbeddingMode::ROPE, + nullptr, + seqpos_, + theta, + 0, + 0, + 0, + block_tables_ptr, + page_size, + block_tables_b_stride, + (cache_seqpos_ + .packed_accessor32()), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + + return XQ_O; +} + +at::Tensor xpos_qkv_decoding( + at::Tensor XQ, + at::Tensor XK, + at::Tensor XV, + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor seqpos, + double theta, + double gamma, + double scale_base, + double exponent_offset, + std::optional num_groups, + std::optional block_tables, + int64_t page_size, + std::optional actual_batch_size, + std::optional batch, + std::optional cache_seqpos, + int64_t cache_logical_dtype_int, + bool rope_scaling = false, + int64_t old_context_len = 8192, + double scaling_factor = 16, + double lo_freq_factor = 1, + double hi_freq_factor = 32) { + auto B = XQ.size(0); + auto N_H = XQ.size(1); + auto N_KVH = XK.size(1); + + TORCH_CHECK(XQ.size(2) % 4 == 0); + int32_t num_warps = B * (2 * N_KVH + N_H); + TORCH_CHECK(num_warps > 0); + + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock)); + auto XQ_O = at::empty_like(XQ); + CacheLogicalDtype cache_logical_dtype = + static_cast(cache_logical_dtype_int); + + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + + int64_t* actual_batch_size_ptr = nullptr; + if (actual_batch_size.has_value()) { + actual_batch_size_ptr = + static_cast(actual_batch_size.value().data_ptr()); + } + auto cache_seqpos_ = cache_seqpos.value_or(seqpos); + if (cache_K.dtype() == at::kBFloat16) { + rope_xpos_qkv_varseq_prefill_kernel + <<>>( + XQ.packed_accessor32(), + XK.packed_accessor32(), + XV.packed_accessor32(), + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + XQ_O.packed_accessor32(), + batch.has_value() ? batch.value().data_ptr() : nullptr, + seqpos.packed_accessor32(), + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + cache_seqpos_ + .packed_accessor32(), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto num_groups_ = num_groups ? num_groups.value() : 1; + auto seqpos_ = + seqpos.packed_accessor32(); + if (cache_logical_dtype == CacheLogicalDtype::FP8) { +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) + CUDA_KERNEL_ASSERT(num_groups_ == 1); + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL( + 1, + CacheLogicalDtype::FP8, + PositionEmbeddingMode::XPOS, + nullptr, + seqpos_, + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + (cache_seqpos_ + .packed_accessor32()), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else + throw std::runtime_error("CUDA version is older than 12.0"); +#endif + } else { + CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK( + CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL, + num_groups_, + CacheLogicalDtype::INT4, + PositionEmbeddingMode::XPOS, + nullptr, + seqpos_, + theta, + gamma, + scale_base, + exponent_offset, + block_tables_ptr, + page_size, + block_tables_b_stride, + (cache_seqpos_ + .packed_accessor32()), + actual_batch_size_ptr, + rope_scaling, + old_context_len, + scaling_factor, + lo_freq_factor, + hi_freq_factor); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + return XQ_O; +} + +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) +__global__ void dequantize_fp8_cache_kernel( + // This code currently represents FP8 version not int4 + at::PackedTensorAccessor64 + cache_K, // [B][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V, // [B][MAX_T][N_KVH][D_H // G] + at::PackedTensorAccessor32 kv_seqlen, + at::PackedTensorAccessor64 + cache_K_dq, // [B][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V_dq // [B][MAX_T][N_KVH][D_H] +) { + auto N_KVH = cache_K.size(2); + auto MAX_T = cache_K.size(1); + auto D_H = cache_K_dq.size(3); + auto D_H_q = cache_K.size(3); + CUDA_KERNEL_ASSERT(D_H_q - D_H == 4); + + auto b = blockIdx.x; + // only need to dequantize this far. + auto max_t = kv_seqlen[b]; + + // one warp per T/H + for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; + t_h += blockDim.y * gridDim.y) { + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + auto* row_k = &cache_K[b][t][h][0]; // uint8_t* + auto* row_v = &cache_V[b][t][h][0]; + bfx8 kv_dq; + __half2 k_shift_scale; + __half2 v_shift_scale; + *reinterpret_cast(&k_shift_scale) = + *reinterpret_cast(&row_k[0]); // reads 32 bits + *reinterpret_cast(&v_shift_scale) = + *reinterpret_cast(&row_v[0]); + if (4 * threadIdx.x >= D_H) { + continue; + } + // each thread reads 4 x 8 bits + + uint64_t kq = *reinterpret_cast(&row_k[threadIdx.x * 4 + 4]); + uint64_t vq = *reinterpret_cast(&row_v[threadIdx.x * 4 + 4]); + + uint64_t packed = kq | (vq << 32); + + kv_dq = dequantize_packed_fp8(packed, k_shift_scale, v_shift_scale); + + // now, write our outputs + auto* row_k_dq = &cache_K_dq[b][t][h][0]; + auto* row_v_dq = &cache_V_dq[b][t][h][0]; + // each thread writes 4 elements of type bf16 + *reinterpret_cast(&row_k_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[0]); + *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[2]); + } +} +std::tuple dequantize_fp8_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen) { + TORCH_CHECK(cache_K.is_cuda()); + TORCH_CHECK(cache_V.is_cuda()); + TORCH_CHECK(kv_seqlen.is_cuda()); + auto B = cache_K.size(0); + auto MAX_T = cache_K.size(1); + auto N_KVH = cache_K.size(2); + auto D_HQ = cache_K.size(3); + auto num_groups = 1; + auto fp8_qparam_offset = num_groups * 4; + auto D_H = (D_HQ - fp8_qparam_offset); + + auto cache_K_dq = + at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + auto cache_V_dq = + at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + + if (B == 0) { + return {cache_K_dq, cache_V_dq}; + } + + constexpr int32_t kMaxBlocks = 256; + dim3 blocks(B, std::max(1, kMaxBlocks / B)); + dim3 threads(kThreadsPerWarp, kWarpsPerBlock); + dequantize_fp8_cache_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return {cache_K_dq, cache_V_dq}; +} + +DEVICE_INLINE void quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q) { + auto thread_min = fminf(fminf(fminf(dst.x, dst.y), dst.z), dst.w); + auto thread_max = fmaxf(fmaxf(fmaxf(dst.x, dst.y), dst.z), dst.w); + + float warp_min, warp_max; + + int32_t fp8_qparam_offset = 4; + unsigned mask = ballot_sync(4 * threadIdx.x < D_H, 0xFFFFFFFF); + warp_min = -warpReduceMax(-thread_min, mask); + warp_max = warpReduceMax(thread_max, mask); + + auto bounded_max = (warp_max - warp_min) / 2; + // TODO: Pass scale_ub + const float* scale_ub = nullptr; + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX::value * 512.f); + if (scale_ub != nullptr) { + bounded_max = std::min(bounded_max, *scale_ub); + } + float scale = static_cast( + std::max(bounded_max / FP8_E4M3_MAX::value, min_scaling_factor)); + float inv_scale = 1 / scale; + float shift = warp_min + FP8_E4M3_MAX::value * scale; + + auto x_0 = __nv_fp8_e4m3((dst.x - shift) * inv_scale); + auto x_1 = __nv_fp8_e4m3((dst.y - shift) * inv_scale); + auto x_2 = __nv_fp8_e4m3((dst.z - shift) * inv_scale); + auto x_3 = __nv_fp8_e4m3((dst.w - shift) * inv_scale); + + uint32_t x_bits[4]; + x_bits[0] = *reinterpret_cast(&x_0); + x_bits[1] = *reinterpret_cast(&x_1); + x_bits[2] = *reinterpret_cast(&x_2); + x_bits[3] = *reinterpret_cast(&x_3); + + uint32_t packed = 0; + + packed |= (x_bits[0] << 0); + packed |= (x_bits[1] << 8); + packed |= (x_bits[2] << 16); + packed |= (x_bits[3] << 24); + + CUDA_KERNEL_ASSERT( + uintptr_t(&dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) % 4 == 0); + + *reinterpret_cast( + &dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) = packed; + if (threadIdx.x == 0) { + CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[0]) % 4 == 0); + __half2 quant_params = __floats2half2_rn(scale, shift); + *reinterpret_cast<__half2*>(&dst_row_q[0]) = quant_params; + } +} +#else +DEVICE_INLINE void quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q) {} +std::vector quantize_fp8_per_tensor( + at::Tensor input, + std::optional bs, // batch size + std::optional scale_ub) { // scale upperbound + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +std::tuple dequantize_fp8_cache( + at::Tensor cache_K, + at::Tensor cache_V, + at::Tensor kv_seqlen) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} +#endif +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu index e461c43ad1..8cecc44af7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu @@ -44,11 +44,7 @@ #include #endif -#if ( \ - defined(__CUDA_ARCH__) && \ - ((__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900))) -#define USE_WMMA_FRAG -#endif +#include /// @defgroup FP8/INT8 quantized FC Operators /// @@ -69,49 +65,6 @@ namespace fbgemm_gpu { // each warp compute sum(t_subset) P[t] * V[t_subset, d] // outputs are of size float[D] -#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) - -static __host__ DEVICE_INLINE int32_t div_up(int32_t a, int32_t b) { - return (a + b - 1) / b; -}; - -static __host__ DEVICE_INLINE int32_t round_up(int32_t a, int32_t b) { - return ((a + b - 1) / b) * b; -} - -#ifdef __HIP_PLATFORM_AMD__ -#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) -constexpr int32_t kThreadsPerWarp = 64; -// constexpr int32_t kWarpsPerBlock = 16; -#endif -#else -constexpr int32_t kThreadsPerWarp = 32; -// constexpr int32_t kWarpsPerBlock = 32; -#endif - -// constexpr int32_t D_H = 128; -// MAX_T: max seq len. We need to make sure shared memory size -// (https://fburl.com/code/ruc41vc7) <= limit of V100/A100/H100 GPUs -// (https://fburl.com/code/gh9j9go4). -// constexpr int32_t MAX_T = 16384; -// constexpr int SMEM_ADJUST_THRESHOLD = 48 * 1024; - -#ifdef __HIP_PLATFORM_AMD__ -static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) { - // float output; - // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html - return float(f); -} - -static __host__ __device__ __nv_bfloat162 -__floats2bfloat162_rn(float x, float y) { - __nv_bfloat162 output; - output.x = __float2bfloat16_rn(x); - output.y = __float2bfloat16_rn(y); - return output; -} -#endif - #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) using __nv_fp8x4_e4m3 = __hip_fp8x4_e4m3_fnuz; using __nv_fp8_e4m3 = __hip_fp8_e4m3_fnuz; @@ -123,53 +76,6 @@ using __nv_fp8_e5m2 = __hip_fp8_e5m2_fnuz; #define torch_fp8_e5m2 at::kFloat8_e5m2 #endif -struct __align__(16) bf16x8 { - __nv_bfloat162 vals[4]; -}; - -struct __align__(16) fx4 { - float x; - float y; - float z; - float w; - __host__ __device__ fx4() { - x = 0; - y = 0; - z = 0; - w = 0; - } -}; - -struct __align__(8) bfx4 { - __nv_bfloat162 vals[2]; -}; - -struct __align__(16) bfx8 { - __nv_bfloat162 vals[4]; -}; - -DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 shift_scale_0); -DEVICE_INLINE bfx8 dequantize_packed_int4( - uint32_t v, - __half2 shift_scale_0, - __half2 shift_scale_1); - -DEVICE_INLINE float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#elif defined(USE_ROCM) - float2 f_val; - f_val.x = __bfloat162float(val.x); - f_val.y = __bfloat162float(val.y); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - struct __align__(8) i8x8 { int8_t vals[8]; }; @@ -245,7 +151,8 @@ at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale) { constexpr int32_t kThreadsPerBlock = 1024; auto XQ = at::empty({X.numel()}, X.options().dtype(at::kChar)); dim3 threads = kThreadsPerBlock; - dim3 blocks = cuda_calc_block_count(div_up(X.numel(), 8), kThreadsPerBlock); + dim3 blocks = + cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock); per_tensor_quantize_i8_kernel<<< blocks, threads, @@ -271,7 +178,8 @@ std::tuple per_tensor_dynamic_quantize_i8( .to(X.dtype()); dim3 threads = kThreadsPerBlock; - dim3 blocks = cuda_calc_block_count(div_up(X.numel(), 8), kThreadsPerBlock); + dim3 blocks = + cuda_calc_block_count(div_round_up(X.numel(), 8), kThreadsPerBlock); per_tensor_quantize_i8_kernel<<< blocks, @@ -653,29 +561,6 @@ void invokeQuantizeMatrixColwise( C10_CUDA_KERNEL_LAUNCH_CHECK(); } -#define FINAL_MASK 0xffffffff - -template -DEVICE_INLINE T shfl_xor( - unsigned shfl_sync_mask, - const T val, - int laneMask, - int width = kThreadsPerWarp) { -#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION < 9000 - return __shfl_xor(val, laneMask, width); -#else - return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width); -#endif -} - -template -DEVICE_INLINE T warpReduceMax(T val, uint32_t warp_mask = FINAL_MASK) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, shfl_xor(warp_mask, val, mask, 32)); - return val; -} - template __inline__ __device__ T blockReduceMax(T val) { static __shared__ T shared[32]; diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py new file mode 100644 index 0000000000..17a58c6204 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py @@ -0,0 +1,578 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[56] + +import logging +import unittest +from enum import Enum, unique +from typing import List, Optional, Tuple + +import fbgemm_gpu.experimental.gen_ai # noqa: F401 +import torch +from hypothesis import given, settings, strategies as st + +try: + from xformers.attn_bias_utils import pack_kv_cache + from xformers.ops import fmha + + HAS_XFORMERS = True +except ImportError: + HAS_XFORMERS = False + +if HAS_XFORMERS: + from rope_padded import rope_padded + + +@unique +class LogicalDtype(Enum): + bf16 = 0 + fp8 = 1 + int4 = 2 + + +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def _get_varseq_batch_seqpos( + seqlens_q: List[int], seqlens_kv: List[int] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + varseq_batch[i] is batch index of query i + varseq_seqpos[i] is the offset of the last key which query i attends to + """ + + varseq_batch = torch.cat( + [ + torch.as_tensor([i for _ in range(len_q)], dtype=torch.int, device="cuda") + for i, len_q in enumerate(seqlens_q) + ] + ) + varseq_seqpos = torch.cat( + [ + torch.as_tensor( + [len_kv - len_q + t for t in range(len_q)], + dtype=torch.int, + device="cuda", + ) + for len_q, len_kv in zip(seqlens_q, seqlens_kv) + ] + ) + return varseq_batch, varseq_seqpos + + +class KVCacheTests(unittest.TestCase): + @settings(deadline=None) + @given( + num_groups=st.sampled_from([1, 2, 4, 8]), + MAX_T=st.sampled_from([8000, 16384]), + N_KVH_L=st.sampled_from([1, 2]), + ) + @unittest.skipIf( + not HAS_XFORMERS, + "Skip when xformers is not available", + ) + def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None: + N_H_L = 2 + T = 2 + B = 2 + D_H = 128 + # D = 8192 + # D_H = 128 + # B = 16 + # PROMPT_T = 1024 + + xq = ( + torch.randn(size=(B * T, N_H_L, D_H), dtype=torch.bfloat16, device="cuda") + * 0.01 + ) + xk = ( + torch.randn(size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda") + * 0.01 + ) + xv = ( + torch.randn(size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda") + * 0.01 + ) + varseq_seqpos = torch.cat( + [ + torch.as_tensor(list(range(T)), dtype=torch.int, device="cuda") + for b in range(B) + ] + ) + varseq_batch = torch.cat( + [ + torch.as_tensor([b for _ in range(T)], dtype=torch.int, device="cuda") + for b in range(B) + ] + ) + attn_bias = ( + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[T for _ in range(B)], + kv_padding=MAX_T, + kv_seqlen=[T for _ in range(B)], + ) + ) + attn_bias.k_seqinfo.to(torch.device("cuda")) + assert attn_bias.k_seqinfo.seqlen.shape == (B,) + assert attn_bias.k_seqinfo.seqlen.tolist() == [T for _ in range(B)] + + theta = 10000.0 + cache_k_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + + xq_out_bf16 = torch.ops.fbgemm.rope_qkv_varseq_prefill( + xq, + xk, + xv, + cache_k_bf16, + cache_v_bf16, + varseq_batch, + varseq_seqpos, + theta, + ) + qparam_offset = 4 * num_groups + + cache_k_int4 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H // 2) + qparam_offset), + dtype=torch.uint8, + device="cuda", + ) + cache_v_int4 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H // 2) + qparam_offset), + dtype=torch.uint8, + device="cuda", + ) + xq_out = torch.ops.fbgemm.rope_qkv_varseq_prefill( + xq, + xk, + xv, + cache_k_int4, + cache_v_int4, + varseq_batch, + varseq_seqpos, + theta, + num_groups=num_groups, + cache_logical_dtype_int=LogicalDtype.int4.value, + ) + torch.testing.assert_close(xq_out_bf16, xq_out) + + cache_k, cache_v = torch.ops.fbgemm.dequantize_int4_cache( + cache_k_int4, + cache_v_int4, + attn_bias.k_seqinfo.seqlen, + num_groups=num_groups, + ) + + torch.testing.assert_close( + cache_k[:, :T], cache_k_bf16[:, :T], atol=1.0e-2, rtol=1.0e-2 + ) + torch.testing.assert_close( + cache_v[:, :T], cache_v_bf16[:, :T], atol=1.0e-2, rtol=1.0e-2 + ) + + @settings(deadline=None) + @given( + MAX_T=st.sampled_from([8000, 16384]), + N_KVH_L=st.sampled_from([1, 2]), + ) + @unittest.skipIf( + not torch.cuda.is_available() + or ( + torch.version.cuda + and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9 + ) + or (torch.version.hip and torch.version.hip < "6.2") + or not HAS_XFORMERS, + "Skip when H100 is not available or MI300 is not available", + ) + def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None: + N_H_L = 2 + T = 2 + B = 2 + D_H = 128 + + xq = ( + torch.cat( + [ + torch.randn(N_H_L, D_H, dtype=torch.bfloat16, device="cuda") * (i) + for i in range(B * T) + ] + ) + ).view(B * T, N_H_L, D_H) + scale_step = 0.01 / B / T + shift_step = 5 * scale_step + xk_rows = [ + scale_step + * (i + 1) + * torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda") + + i * shift_step + for i in range(B * T) + ] + xv_rows = [ + scale_step + * (i + 1) + * torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda") + + i * shift_step + for i in range(B * T) + ] + + xk = (torch.cat(xk_rows)).view(B * T, N_KVH_L, D_H) + + xv = (torch.cat(xv_rows)).view(B * T, N_KVH_L, D_H) + varseq_seqpos = torch.cat( + [ + torch.as_tensor(list(range(T)), dtype=torch.int, device="cuda") + for b in range(B) + ] + ) + varseq_batch = torch.cat( + [ + torch.as_tensor([b for _ in range(T)], dtype=torch.int, device="cuda") + for b in range(B) + ] + ) + attn_bias = ( + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[T for _ in range(B)], + kv_padding=MAX_T, + kv_seqlen=[T for _ in range(B)], + ) + ) + attn_bias.k_seqinfo.to(torch.device("cuda")) + assert attn_bias.k_seqinfo.seqlen.shape == (B,) + assert attn_bias.k_seqinfo.seqlen.tolist() == [T for _ in range(B)] + + theta = 10000.0 + cache_k_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v_bf16 = torch.zeros( + size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + + xq_out_bf16 = torch.ops.fbgemm.rope_qkv_varseq_prefill( + xq, + xk, + xv, + cache_k_bf16, + cache_v_bf16, + varseq_batch, + varseq_seqpos, + theta, + ) + qparam_offset = 4 + + cache_k_fp8 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset), + dtype=torch.uint8, + device="cuda", + ) + cache_v_fp8 = torch.zeros( + size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset), + dtype=torch.uint8, + device="cuda", + ) + xq_out = torch.ops.fbgemm.rope_qkv_varseq_prefill( + xq, + xk, + xv, + cache_k_fp8, + cache_v_fp8, + varseq_batch, + varseq_seqpos, + theta, + cache_logical_dtype_int=LogicalDtype.fp8.value, + ) + torch.testing.assert_close(xq_out_bf16, xq_out) + + cache_k, cache_v = torch.ops.fbgemm.dequantize_fp8_cache( + cache_k_fp8, + cache_v_fp8, + attn_bias.k_seqinfo.seqlen, + ) + + torch.testing.assert_close( + cache_k[:, :T], cache_k_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2 + ) + torch.testing.assert_close( + cache_v[:, :T], cache_v_bf16[:, :T], atol=1.0e-2, rtol=5.0e-2 + ) + + @settings(deadline=None) + @given( + prefill=st.booleans(), + rope_theta=st.sampled_from([None, 10000.0]), + MAX_T=st.sampled_from([4000, 8192]), + B=st.sampled_from([1, 128]), + BLOCK_N=st.sampled_from([64, 128, 256]), + ) + @unittest.skipIf( + not HAS_XFORMERS, + "Skip when xformers is not available", + ) + def test_positional_encoding_with_paged_attention( + self, + prefill: bool, + rope_theta: Optional[float], + MAX_T: int, + B: int, + BLOCK_N: int, + ) -> None: + + N_H_L = 1 + N_KVH_L = 8 + D_H = 128 + torch.manual_seed(100) + + kv_seqlens = torch.randint(low=0, high=MAX_T, size=(B,)).tolist() + q_seqlens = kv_seqlens if prefill else [1 for _ in range(B)] + seq_positions = torch.tensor( + [x - 1 for x in kv_seqlens], device="cuda", dtype=torch.int32 + ) + total_length_q = sum(q_seqlens) + + cache_k = torch.randn( + (B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v = torch.randn_like(cache_k) + + block_tables, packed_cache_k, packed_cache_v = pack_kv_cache( + cache_k, + cache_v, + [x + 1 for x in seq_positions], + BLOCK_N=BLOCK_N, + ) + + assert packed_cache_k.is_contiguous() + assert packed_cache_v.is_contiguous() + + xqkv = torch.randn( + total_length_q, + N_H_L + 2 * N_KVH_L, + D_H, + dtype=torch.bfloat16, + device="cuda", + ) + xq = xqkv[:, :N_H_L, :] + xk = xqkv[:, N_H_L : N_H_L + N_KVH_L, :] + xv = xqkv[:, N_H_L + N_KVH_L :, :] + + xpos_gamma: float = 0.8 + xpos_scale_base: float = 4096.0 + xpos_theta: float = 500000.0 + xpos_exponent_offset = 0 + + assert cache_k.is_contiguous() + assert cache_v.is_contiguous() + + B_T = total_length_q + assert xq.shape == (B_T, N_H_L, D_H) + assert xk.shape == (B_T, N_KVH_L, D_H) + assert xv.shape == (B_T, N_KVH_L, D_H) + + assert cache_k.shape == (B, MAX_T, N_KVH_L, D_H) + assert cache_v.shape == (B, MAX_T, N_KVH_L, D_H) + + if prefill: + seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens) + else: + seqpos_args = (seq_positions,) + + if rope_theta is not None: + func = ( + torch.ops.fbgemm.rope_qkv_varseq_prefill + if prefill + else torch.ops.fbgemm.rope_qkv_decoding + ) + xq_out_ref = func( + xq, + xk, + xv, + cache_k, + cache_v, + *seqpos_args, + rope_theta, + num_groups=0, + ) + xq_out_paged = func( + xq, + xk, + xv, + packed_cache_k, + packed_cache_v, + *seqpos_args, + rope_theta, + num_groups=0, + block_tables=block_tables, + page_size=BLOCK_N, + ) + else: + func = ( + torch.ops.fbgemm.xpos_qkv_varseq_prefill + if prefill + else torch.ops.fbgemm.xpos_qkv_decoding + ) + xq_out_ref = func( + xq, + xk, + xv, + cache_k, + cache_v, + *seqpos_args, + theta=xpos_theta, + gamma=xpos_gamma, + scale_base=xpos_scale_base, + exponent_offset=xpos_exponent_offset, + num_groups=0, + ) + xq_out_paged = func( + xq, + xk, + xv, + packed_cache_k, + packed_cache_v, + *seqpos_args, + xpos_theta, + xpos_gamma, + xpos_scale_base, + xpos_exponent_offset, + num_groups=0, + block_tables=block_tables, + page_size=BLOCK_N, + ) + torch.testing.assert_close(xq_out_ref, xq_out_paged) + + for b in range(B): + num_blocks = (kv_seqlens[b] + BLOCK_N - 1) // BLOCK_N + for logical_idx in range(num_blocks): + len_to_compare = min(kv_seqlens[b] - logical_idx * BLOCK_N, BLOCK_N) + for kv_ref, kv_packed in ( + (cache_k, packed_cache_k), + (cache_v, packed_cache_v), + ): + physical_idx = block_tables[b][logical_idx] + logical_start = logical_idx * BLOCK_N + physical_start = physical_idx * BLOCK_N + ref_vals = kv_ref[ + b, + logical_start : logical_start + len_to_compare, + ] + packed_vals = kv_packed[0][ + physical_start : physical_start + len_to_compare + ] + torch.testing.assert_close(ref_vals, packed_vals) + + @settings(deadline=None) + @given( + prefill=st.booleans(), + rope_theta=st.sampled_from([10000.0]), + MAX_T=st.sampled_from([8192]), + B=st.sampled_from([128]), + BLOCK_N=st.sampled_from([256]), + ) + @unittest.skipIf( + not HAS_XFORMERS, + "Skip when xformers is not available", + ) + def test_rope_positional_encoding_only( + self, + prefill: bool, + rope_theta: float, + MAX_T: int, + B: int, + BLOCK_N: int, + ) -> None: + N_H_L = 1 + N_KVH_L = 8 + D_H = 128 + torch.manual_seed(100) + + kv_seqlens = torch.randint(low=0, high=MAX_T, size=(B,)).tolist() + q_seqlens = kv_seqlens if prefill else [1 for _ in range(B)] + seq_positions = torch.tensor( + [x - 1 for x in kv_seqlens], device="cuda", dtype=torch.int32 + ) + total_length_q = sum(q_seqlens) + + cache_k = torch.randn( + (B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" + ) + cache_v = torch.randn_like(cache_k) + + xqkv = torch.randn( + total_length_q, + N_H_L + 2 * N_KVH_L, + D_H, + dtype=torch.bfloat16, + device="cuda", + ) + xq = xqkv[:, :N_H_L, :] + xk = xqkv[:, N_H_L : N_H_L + N_KVH_L, :] + xv = xqkv[:, N_H_L + N_KVH_L :, :] + + assert cache_k.is_contiguous() + assert cache_v.is_contiguous() + + B_T = total_length_q + assert xq.shape == (B_T, N_H_L, D_H) + assert xk.shape == (B_T, N_KVH_L, D_H) + assert xv.shape == (B_T, N_KVH_L, D_H) + + assert cache_k.shape == (B, MAX_T, N_KVH_L, D_H) + assert cache_v.shape == (B, MAX_T, N_KVH_L, D_H) + + if prefill: + seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens) + else: + seqpos_args = (seq_positions,) + + func = ( + torch.ops.fbgemm.rope_qkv_varseq_prefill + if prefill + else torch.ops.fbgemm.rope_qkv_decoding + ) + xq_out = func( + xq, + xk, + xv, + cache_k, + cache_v, + *seqpos_args, + rope_theta, + num_groups=0, + ) + xq_out = xq_out.view(1, xq_out.shape[0], xq_out.shape[1], xq_out.shape[2]) + attn_bias = ( + fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=q_seqlens, + kv_padding=MAX_T, + kv_seqlen=kv_seqlens, + ) + ) + attn_bias.k_seqinfo.to(torch.device("cuda")) + xq = xq.view(1, xq.shape[0], N_H_L, D_H) + xk = xk.view(1, xk.shape[0], N_KVH_L, D_H) + xv = xv.view(1, xv.shape[0], N_KVH_L, D_H) + cache_k = cache_k.view(1, B * MAX_T, N_KVH_L, D_H) + cache_v = cache_k.view(1, B * MAX_T, N_KVH_L, D_H) + xq_out_ref = rope_padded( + xq=xq, + xk=xk, + xv=xv, + cache_k=cache_k, + cache_v=cache_v, + attn_bias=attn_bias, + theta=rope_theta, + ) + + torch.testing.assert_close(xq_out, xq_out_ref, atol=0.01, rtol=0.01) diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py new file mode 100644 index 0000000000..ab728db6f2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Dict, Optional + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +try: + from xformers.ops.fmha.attn_bias import ( # type: ignore + BlockDiagonalCausalWithOffsetPaddedKeysMask, + ) + + HAS_XFORMERS = True +except ImportError: + HAS_XFORMERS = False + from typing import Any + + BlockDiagonalCausalWithOffsetPaddedKeysMask = Any + +try: + # @manual=//triton:triton + # pyre-fixme[21]: Could not find module `triton.language.libdevice`. + from triton.language.libdevice import pow +except ImportError: + try: + # @manual=//triton:triton + # pyre-fixme[21]: Could not find name `pow` in `triton.language.math`. + from triton.language.math import pow + except ImportError: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import pow + + +_INTERNAL_DTYPE_MAP: Dict[str, int] = {"": 0, "f32": 1, "f64": 2} + + +@triton.jit +def _rope_padded_kernel( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + k_start: tl.constexpr, + v_start: tl.constexpr, + dim: tl.constexpr, # dimension of each head + stride_xqM, + stride_xqH, + stride_xkM, + stride_xkH, + stride_xvM, + stride_xvH, + stride_cachekM, + stride_cachekH, + stride_cachevM, + stride_cachevH, + stride_seqstartq, + stride_seqstartk, + stride_seqlenk, + stride_outqM, + stride_outqH, + internal_dtype: tl.constexpr, + # If True, seqstartq and seqstartk are not used but rather we + # assume that every batch element has the same number of + # queries (i.e. num_queries := tl.num_programs(1) ) + # and the same cache space cache_padding_length. + # Always False when called below. + const_batch_strides: tl.constexpr, + # If const_batch_strides==True, the common cache length for each batch element. + # (Only the first seqlenk[i] elements are actually in use, and only the last + # num_queries of those are actually written to.) + cache_padding_length, + # offset added to all values in seqlenk before using them. + # Always 0 when called below. + seqlenk_shift: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + adjacents: tl.constexpr, +): + """ + Each letter in this diagram is a whole row of length dim. + + INPUT xq xk xv + + head_dim ─► + + batch qqqqqq kk vv + │ qqqqqq kk vv + ▼ qqqqqq kk vv + + head_idx: (goes across all heads of all 3 inputs) + ▲ ▲ ▲ ▲ ▲ ▲ + │ │ │ │ │ │ + │ │ + 0 k_start │v_start │n_total_heads + │ │ + │ │ + k_start v_start + + Output is to out_q (same shape as xq), an xk-shaped part + of cache_k and an xv-shaped part of cache_v + """ + batch_elt = tl.program_id(0) + query_pos_in_batch_elt = tl.program_id(1) + head_idx = tl.program_id(2) + + if internal_dtype == 1: + theta = theta.to(tl.float32) + elif internal_dtype == 2: + theta = theta.to(tl.float64) + + if const_batch_strides: + query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt + end_query_pos = tl.num_programs(1) * (batch_elt + 1) + else: + query_pos = query_pos_in_batch_elt + tl.load( + seqstartq + batch_elt * stride_seqstartq + ) + end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq) + if query_pos >= end_query_pos: + return + + is_q = head_idx < k_start + is_v = head_idx >= v_start + + xq += query_pos * stride_xqM + head_idx * stride_xqH + out_q += query_pos * stride_outqM + head_idx * stride_outqH + + if const_batch_strides: + cache_start = cache_padding_length * batch_elt + else: + cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk) + end_of_batch_elt_cache = ( + cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift + ) + + cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos) + seq_pos = cache_pos - cache_start + cache_k += (head_idx - k_start) * stride_cachekH + cache_pos * stride_cachekM + xk += query_pos * stride_xkM + (head_idx - k_start) * stride_xkH + in_qk = tl.where(is_q, xq, xk) + out_qk = tl.where(is_q, out_q, cache_k) + + cache_v += (head_idx - v_start) * stride_cachevH + cache_pos * stride_cachevM + xv += query_pos * stride_xvM + (head_idx - v_start) * stride_xvH + + out = tl.where(is_v, cache_v, out_qk) + x_in = tl.where(is_v, xv, in_qk) + + for offset in range(0, dim // 2, BLOCK_SIZE // 2): + c = tl.arange(0, BLOCK_SIZE // 2) + powers = (offset + c) * 2.0 + if adjacents: + cols_re = (offset + c) * 2 + cols_im = cols_re + 1 + else: + cols_re = offset + c + cols_im = cols_re + dim // 2 + + mask = cols_im < dim + + re_x = tl.load(x_in + cols_re, mask=mask) + im_x = tl.load(x_in + cols_im, mask=mask) + # freqs = seq_pos / (theta ** (powers / dim)) + # pyre-fixme[16]: Module `language` has no attribute `libdevice`. + freqs = seq_pos * pow(theta, powers / (-dim)) + sines = tl.sin(freqs) + cosines = tl.cos(freqs) + re_out = re_x * cosines - im_x * sines + im_out = im_x * cosines + re_x * sines + + re_out_ = tl.where(is_v, re_x, re_out) + im_out_ = tl.where(is_v, im_x, im_out) + if internal_dtype == 2: + if re_x.dtype == tl.bfloat16: + # triton 2.0.0 crashes if you try to convert + # float64 directly to bfloat16, so make an intermediate step. + re_out_ = re_out_.to(tl.float32) + im_out_ = im_out_.to(tl.float32) + tl.store(out + cols_re, re_out_, mask=mask) + tl.store(out + cols_im, im_out_, mask=mask) + + +def rope_padded( + xq, + xk, + xv, + cache_k, + cache_v, + attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask, + *, + theta: float = 10000.0, + out_q: Optional[torch.Tensor] = None, + adjacents: bool = True, + internal_dtype: str = "", +): + """ + Applies rope to a heterogeneous batch in the style given + by xformers' BlockDiagonalCausalWithOffsetPaddedKeysMask. + The batch is concatted along the sequence dimension, so the + actual xformers batch size needs to be 1. + + xq, xk and xv should be (1, slen, n_heads, dim), where xq's n_heads can differ from xk and xv + + This function places the roped xk in the right place in cache_k and + xv (unmodified) in the right place in cache_v, and returns out_q + such that things are ready to call + xformers.ops.memory_efficient_attention(out_q, cache_k, cache_v, attn_bias=attn_bias) + + WARNING: This function relies on private details of xformers. + + Arguments: + xq: tensor of queries to apply rope to + xk: tensor of keys to apply rope to + xv: tensor of values to copy into cache_v + cache_k: cache of keys, modified in place + cache_v: cache of values, modified in place + attn_bias: details the layout of caches. + Used to determine frequencies for the + RoPE calculation as well as the locations in cache_k and cache_v + to write to. Must be on the device. + adjacents: If True, the inputs are in adjacent pairs along the final dim axis. + This is like the released LLaMA model and xlformers. + If False, the dim axis is split in two equal pieces. + I.e. the features are ordered with all the real parts before all + the imaginary parts. This matches HuggingFace right now. + https://github.com/huggingface/transformers/blob/ + f143037789288ba532dada934a118e648e715738/ + src/transformers/models/llama/modeling_llama.py#L126-L130 + internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation + """ + n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1] + cache_length = attn_bias.k_seqinfo.seqstart_py[-1] + assert xq.shape[1] == n_total_queries + bsz, _, n_q_heads, dim = xq.shape + assert bsz == 1 + n_kv_heads = xk.shape[2] + assert xk.shape == (1, n_total_queries, n_kv_heads, dim) + assert xv.shape == (1, n_total_queries, n_kv_heads, dim) + assert cache_k.shape == (1, cache_length, n_kv_heads, dim) + assert cache_v.shape == (1, cache_length, n_kv_heads, dim) + assert xq.stride(3) == 1 + assert xk.stride(3) == 1 + assert xv.stride(3) == 1 + assert cache_k.stride(3) == 1 + assert cache_v.stride(3) == 1 + n_total_heads = n_q_heads + 2 * n_kv_heads + v_start = n_total_heads - n_kv_heads + k_start = n_q_heads + if out_q is None: + out_q = xq.new_empty(1, n_total_queries, n_q_heads, dim) + else: + assert out_q.shape == xq.shape + assert out_q.stride(3) == 1 + assert out_q is not None + + logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1 + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // xq.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + device = xq.device + # Move these to the right device, like fmha does. + attn_bias.k_seqinfo.to(device) + attn_bias.q_seqinfo.to(device) + seqstartq = attn_bias.q_seqinfo.seqstart + seqstartk = attn_bias.k_seqinfo.seqstart + seqlenk = attn_bias.k_seqinfo.seqlen + assert internal_dtype in ["", "f32", "f64"] + # experiment with the order of dims here. + with torch.cuda.device(xq.device.index): + # pyre-fixme[28]: Unexpected keyword argument `num_warps`. + _rope_padded_kernel[ + (logical_bsz, attn_bias.q_seqinfo.max_seqlen, n_total_heads) + ]( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + k_start, + v_start, + dim, + xq.stride(1), + xq.stride(2), + xk.stride(1), + xk.stride(2), + xv.stride(1), + xv.stride(2), + cache_k.stride(1), + cache_k.stride(2), + cache_v.stride(1), + cache_v.stride(2), + seqstartq.stride(0), + seqstartk.stride(0), + seqlenk.stride(0), + out_q.stride(1), + out_q.stride(2), + _INTERNAL_DTYPE_MAP[internal_dtype], + const_batch_strides=False, + cache_padding_length=0, + seqlenk_shift=0, + BLOCK_SIZE=BLOCK_SIZE, + adjacents=adjacents, + num_warps=num_warps, + ) + return out_q diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 59b0773568..da4dfc7919 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -159,6 +159,10 @@ __host__ DEVICE_INLINE int32_t div_round_up(int32_t a, int32_t b) { return (a + b - 1) / b; } +static __host__ DEVICE_INLINE int32_t round_up(int32_t a, int32_t b) { + return ((a + b - 1) / b) * b; +} + __host__ DEVICE_INLINE int32_t round_down(int32_t a, int32_t b) { return a / b * b; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh new file mode 100644 index 0000000000..d61285b7ff --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh @@ -0,0 +1,431 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "fbgemm_gpu/utils/cuda_prelude.cuh" +#include "fbgemm_gpu/utils/float.cuh" +#include "fbgemm_gpu/utils/types.h" + +#if !( \ + defined(USE_ROCM) || \ + ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#include +#elif (defined(USE_ROCM)) +#include +#include +#endif + +#if CUDART_VERSION >= 12000 +#include +#elif (defined(USE_ROCM) && ROCM_VERSION >= 60200) +#include +#endif + +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) +using __nv_fp8_e4m3 = __hip_fp8_e4m3_fnuz; +#endif + +namespace fbgemm_gpu { + +#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) + +#ifdef __HIP_PLATFORM_AMD__ +// #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) +constexpr int32_t kThreadsPerWarp = 64; +constexpr int32_t kWarpsPerBlock = 16; +// #endif +#else +constexpr int32_t kThreadsPerWarp = 32; +constexpr int32_t kWarpsPerBlock = 32; +#endif + +constexpr int32_t D_H = 128; + +#ifdef __HIP_PLATFORM_AMD__ + +using __nv_bfloat16 = hip_bfloat16; + +static __host__ __device__ float __bfloat162float(__nv_bfloat16 f) { + // float output; + // https://docs.amd.com/projects/HIP/en/docs-5.0.0/doxygen/html/hip__bfloat16_8h_source.html + return float(f); +} + +static __host__ __device__ __nv_bfloat162 +__floats2bfloat162_rn(float x, float y) { + __nv_bfloat162 output; + output.x = __float2bfloat16_rn(x); + output.y = __float2bfloat16_rn(y); + return output; +} + +#endif + +struct __align__(16) bf16x8 { + __nv_bfloat162 vals[4]; +}; + +struct __align__(16) fx4 { + float x; + float y; + float z; + float w; + __host__ __device__ fx4() { + x = 0; + y = 0; + z = 0; + w = 0; + } +}; + +struct __align__(8) bfx4 { + __nv_bfloat162 vals[2]; +}; + +struct __align__(16) bfx8 { + __nv_bfloat162 vals[4]; +}; +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) +DEVICE_INLINE bfx4 dequantize_packed_fp8(uint32_t vs, __half2 shift_scale_0); +#endif +DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 shift_scale_0); +DEVICE_INLINE bfx8 dequantize_packed_int4( + uint32_t v, + __half2 shift_scale_0, + __half2 shift_scale_1); + +DEVICE_INLINE float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#elif defined(USE_ROCM) + float2 f_val; + f_val.x = __bfloat162float(val.x); + f_val.y = __bfloat162float(val.y); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +#define CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(NAME, NUM_GROUPS, ...) \ + switch (NUM_GROUPS) { \ + case 1: \ + NAME(1, __VA_ARGS__); \ + break; \ + case 2: \ + NAME(2, __VA_ARGS__); \ + break; \ + case 4: \ + NAME(4, __VA_ARGS__); \ + break; \ + case 8: \ + NAME(8, __VA_ARGS__); \ + break; \ + case 16: \ + TORCH_CHECK( \ + false, \ + "With head dim = 128 we're almost even with int8 at this point. Are you sure about this? Num groups:", \ + NUM_GROUPS); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported number of groups: ", NUM_GROUPS); \ + } + +DEVICE_INLINE float bfx4_dot(bfx4 a, bfx4 b) { + // float2 acc = {0, 0}; + // __nv_bfloat162 acc; + // acc.x = static_cast(0); + // acc.y = static_cast(0); + // TODO: need to be performed in float32? + auto a0 = bf1622float2(a.vals[0]); + auto a1 = bf1622float2(a.vals[1]); + auto b0 = bf1622float2(b.vals[0]); + auto b1 = bf1622float2(b.vals[1]); + return a0.x * b0.x + a0.y * b0.y + a1.x * b1.x + a1.y * b1.y; + + // acc = __hfma2(a.vals[0], b.vals[0], acc); + // acc = __hfma2(a.vals[1], b.vals[1], acc); + // auto r = bf1622float2(acc); + // return r.x + r.y; +} + +DEVICE_INLINE fx4 bfx4_scale_acc(fx4 acc, bfx4 a, float b) { + auto axy = bf1622float2(a.vals[0]); + auto azw = bf1622float2(a.vals[1]); + acc.x += axy.x * b; + acc.y += axy.y * b; + acc.z += azw.x * b; + acc.w += azw.y * b; + return acc; +} + +DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) { + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; + return a; +} + +DEVICE_INLINE bfx4 fx4_to_bfx4(fx4 a) { + bfx4 r; + r.vals[0] = __floats2bfloat162_rn(a.x, a.y); + r.vals[1] = __floats2bfloat162_rn(a.z, a.w); + return r; +} + +#define FINAL_MASK 0xffffffff + +template +DEVICE_INLINE T shfl_xor( + unsigned shfl_sync_mask, + const T val, + int laneMask, + int width = kThreadsPerWarp) { +#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION < 9000 + return __shfl_xor(val, laneMask, width); +#else + return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width); +#endif +} + +template +DEVICE_INLINE T warpReduceSum(T val, uint32_t warp_mask = FINAL_MASK) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += shfl_xor(warp_mask, val, mask, 32); + return val; +} + +template +DEVICE_INLINE T warpReduceMax(T val, uint32_t warp_mask = FINAL_MASK) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, shfl_xor(warp_mask, val, mask, 32)); + return val; +} + +struct __align__(8) halfx4 { + __half2 vals[2]; +}; + +struct __align__(16) halfx8 { + __half2 vals[4]; +}; + +DEVICE_INLINE bfx4 dequantize_packed_int4(uint16_t vs, __half2 shift_scale_0) { + uint32_t v = vs; + // move 2nd byte to 3rd byte, so our bits are in 0x00FF00FF positions. + v = (v & 0xFF) | ((v & 0xFF00) << 8); + + halfx4 res; + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + // ~5% perf gain is observed with the explicit type conversions using + // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using + // NVCC 11.0. Additionally, HIP compiler requires these explicit type + // conversions. + half shift_scale_0_x = __low2half(shift_scale_0); + half shift_scale_0_y = __high2half(shift_scale_0); + + // now, dequantize + auto shifts = __half2(shift_scale_0_y, shift_scale_0_y); + auto scales_lower = __half2( + __hmul(shift_scale_0_x, __float2half(512)), + __hmul(shift_scale_0_x, __float2half(512))); + auto scales_upper = __half2( + __hmul(shift_scale_0_x, __float2half(32)), + __hmul(shift_scale_0_x, __float2half(32))); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + + bfx4 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r0.y, r1.y); + return result; +} + +DEVICE_INLINE bfx8 dequantize_packed_int4( + uint32_t v, + __half2 shift_scale_0, + __half2 shift_scale_1) { + halfx8 res; + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + v >>= 8; + res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + half shift_scale_0_x = __low2half(shift_scale_0); + half shift_scale_0_y = __high2half(shift_scale_0); + half shift_scale_1_x = __low2half(shift_scale_1); + half shift_scale_1_y = __high2half(shift_scale_1); + + // now, dequantize + auto shifts = __half2(shift_scale_0_y, shift_scale_1_y); + auto scales_lower = __half2( + __hmul(shift_scale_0_x, __float2half(512)), + __hmul(shift_scale_1_x, __float2half(512))); + auto scales_upper = __half2( + __hmul(shift_scale_0_x, __float2half(32)), + __hmul(shift_scale_1_x, __float2half(32))); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); + auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); + result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); + result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); + return result; +} + +__forceinline__ __device__ bfx8 +dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { + halfx8 res; + uint32_t v = packedVals; + // What's going on here, you might ask? We extra out 4-bit pairs of integers + // as 2xuint16 packed into an int32 via the mask operation, and then we + // convert them to half precision values. As these are all integers in [0, + // 15], we can actually just interpret the 4-bit integer values as + // half-precision values. We multiply by 4096 x 4096 to go from the 4-bit + // representation to the equivalent fp16 value, or alternatively 32768 * 512 + // (or 32 when we have shifted the 4-bit value up). See e.g. + // https://gist.github.com/ajtulloch/021254a291a95966bc509db4e34ffeff for a + // NumPy implementation. We do this dance because: a) doing bitwise operations + // on each 4-bit value is expensive on the ALU, and 4-bit to half is expensive + // on the XU. b) doing a 256-entry shared memory LUT on 8-bit pairs is + // expensive on SMEM throughput. Credit to @jhj. + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + v >>= 8; + res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + // ~5% perf gain is observed with the explicit type conversions using + // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using + // NVCC 11.0. Additionally, HIP compiler requires these explicit type + // conversions. + half shift_scale_x = __low2half(shift_scale); + half shift_scale_y = __high2half(shift_scale); + + // now, dequantize + auto shifts = __half2(shift_scale_y, shift_scale_y); + auto scales_lower_temp = __hmul(shift_scale_x, __float2half(512)); + auto scales_lower = __half2(scales_lower_temp, scales_lower_temp); + auto scales_upper_temp = __hmul(shift_scale_x, __float2half(32)); + auto scales_upper = __half2(scales_upper_temp, scales_upper_temp); + + auto r0 = __half22float2(__hfma2(res.vals[0], scales_lower, shifts)); + auto r1 = __half22float2(__hfma2(res.vals[1], scales_upper, shifts)); + auto r2 = __half22float2(__hfma2(res.vals[2], scales_lower, shifts)); + auto r3 = __half22float2(__hfma2(res.vals[3], scales_upper, shifts)); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r1.x); + result.vals[1] = __floats2bfloat162_rn(r2.x, r3.x); + result.vals[2] = __floats2bfloat162_rn(r0.y, r1.y); + result.vals[3] = __floats2bfloat162_rn(r2.y, r3.y); + + return result; +} + +enum class CacheLogicalDtype { BF16, FP8, INT4 }; + +#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \ + (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) +DEVICE_INLINE bfx8 dequantize_packed_fp8_symmetric( + uint64_t v, // Vq1 Vq0 Kq1 Kq0 + float scale_0, // k scale + float scale_1) { // v scale + uint32_t k_ = v & 0xFFFFFFFF; // 32 LSB + __nv_fp8_e4m3* fp8_k = reinterpret_cast<__nv_fp8_e4m3*>(&k_); + v >>= 32; + uint32_t v_ = v & 0xFFFFFFFF; + __nv_fp8_e4m3* fp8_v = reinterpret_cast<__nv_fp8_e4m3*>(&v_); + + // now, dequantize + auto r0 = make_float2(float(fp8_k[0]) * scale_0, float(fp8_k[1]) * scale_0); + auto r1 = make_float2(float(fp8_k[2]) * scale_0, float(fp8_k[3]) * scale_0); + auto r2 = make_float2(float(fp8_v[0]) * scale_1, float(fp8_v[1]) * scale_1); + auto r3 = make_float2(float(fp8_v[2]) * scale_1, float(fp8_v[3]) * scale_1); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r0.y); // (k0_dq, k1_dq) + result.vals[1] = __floats2bfloat162_rn(r1.x, r1.y); + result.vals[2] = __floats2bfloat162_rn(r2.x, r2.y); // (v0_dq, v1_dq) + result.vals[3] = __floats2bfloat162_rn(r3.x, r3.y); + return result; +} +DEVICE_INLINE bfx4 dequantize_packed_fp8(uint32_t vs, __half2 shift_scale_0) { + uint32_t v = vs; + __nv_fp8_e4m3* fp8_k = reinterpret_cast<__nv_fp8_e4m3*>(&v); // 4 element + + auto shift_0 = __half2float(__high2half(shift_scale_0)); + auto scale_0 = __half2float(__low2half(shift_scale_0)); + + // now, dequantize + auto r0 = make_float2( + float(fp8_k[0]) * scale_0 + shift_0, float(fp8_k[1]) * scale_0 + shift_0); + auto r1 = make_float2( + float(fp8_k[2]) * scale_0 + shift_0, float(fp8_k[3]) * scale_0 + shift_0); + + bfx4 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r0.y); + result.vals[1] = __floats2bfloat162_rn(r1.x, r1.y); + return result; +} +DEVICE_INLINE bfx8 dequantize_packed_fp8( + uint64_t v, // Vq1 Vq0 Kq1 Kq0 + __half2 shift_scale_k, + __half2 shift_scale_v) { + uint32_t k_ = v & 0xFFFFFFFF; // 32 LSB + __nv_fp8_e4m3* fp8_k = reinterpret_cast<__nv_fp8_e4m3*>(&k_); + v >>= 32; + uint32_t v_ = v & 0xFFFFFFFF; + __nv_fp8_e4m3* fp8_v = reinterpret_cast<__nv_fp8_e4m3*>(&v_); + + auto shift_0 = __half2float(__high2half(shift_scale_k)); + auto scale_0 = __half2float(__low2half(shift_scale_k)); + auto shift_1 = __half2float(__high2half(shift_scale_v)); + auto scale_1 = __half2float(__low2half(shift_scale_v)); + + // now, dequantize + auto r0 = make_float2( + float(fp8_k[0]) * scale_0 + shift_0, float(fp8_k[1]) * scale_0 + shift_0); + auto r1 = make_float2( + float(fp8_k[2]) * scale_0 + shift_0, float(fp8_k[3]) * scale_0 + shift_0); + auto r2 = make_float2( + float(fp8_v[0]) * scale_1 + shift_1, float(fp8_v[1]) * scale_1 + shift_1); + auto r3 = make_float2( + float(fp8_v[2]) * scale_1 + shift_1, float(fp8_v[3]) * scale_1 + shift_1); + + bfx8 result; + result.vals[0] = __floats2bfloat162_rn(r0.x, r0.y); // (k0_dq, k1_dq) + result.vals[1] = __floats2bfloat162_rn(r1.x, r1.y); + result.vals[2] = __floats2bfloat162_rn(r2.x, r2.y); // (v0_dq, v1_dq) + result.vals[3] = __floats2bfloat162_rn(r3.x, r3.y); + return result; +} +#endif + +} // namespace fbgemm_gpu