diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 2aee262..8be87fb 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -91,7 +91,7 @@ jobs:
           # Install torch
           $cudaVersion = $env:CUDA_VERSION.Replace('.', '')
           $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
-          if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.0" } else {$pytorchVersion = "torch==2.0.1"}
+          $pytorchVersion = "torch==2.2.1"
           python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
           python -m pip install build setuptools wheel ninja
diff --git a/awq_ext/pybind_awq_v2.cpp b/awq_ext/pybind_awq_v2.cpp
new file mode 100644
index 0000000..9499e8b
--- /dev/null
+++ b/awq_ext/pybind_awq_v2.cpp
@@ -0,0 +1,10 @@
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include "quantization_new/gemm/gemm_cuda.h"
+#include "quantization_new/gemv/gemv_cuda.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("gemm_forward_cuda_prefill", &gemm_forward_cuda_prefill, "New quantized GEMM kernel.");
+    m.def("gemv_forward_cuda_decode", &gemv_forward_cuda_decode, "New quantized GEMV kernel.");
+}
\ No newline at end of file
diff --git a/awq_ext/quantization_new/dequantize.cuh b/awq_ext/quantization_new/dequantize.cuh
new file mode 100644
index 0000000..fa02fb7
--- /dev/null
+++ b/awq_ext/quantization_new/dequantize.cuh
@@ -0,0 +1,77 @@
+/*
+Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+*/
+#include <cuda_fp16.h>
+#pragma once
+
+__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
+{
+    // uint4 result;
+
+    uint32_t *h = reinterpret_cast<uint32_t *>(result);
+    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
+
+    // First, we extract the i4s and construct an intermediate fp16 number.
+    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+    static constexpr uint32_t TOP_MASK = 0x00f000f0;
+    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+    // elt_67 to fp16 without having to shift them to the bottom bits before hand.
+
+    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+    // immediately before required.
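+    // How the magic numbers work: OR-ing a 4-bit value x into the low mantissa bits of 0x6400 yields the fp16
+    // value 1024 + x, so the sub.f16x2 against {1024, 1024} below recovers x exactly. The nibble sitting in
+    // bits 4..7 instead becomes 1024 + 16*x, which the fma.rn.f16x2 below folds back via v * (1/16) - 64 = x.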
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + + // return result; +} \ No newline at end of file diff --git a/awq_ext/quantization_new/gemm/gemm_cuda.cu b/awq_ext/quantization_new/gemm/gemm_cuda.cu new file mode 100644 index 0000000..b9c8c1f --- /dev/null +++ b/awq_ext/quantization_new/gemm/gemm_cuda.cu @@ -0,0 +1,1033 @@ +#include +#include "semaphore.h" +#include "gemm_cuda.h" +#include "../dequantize.cuh" +#include +#include + +#define kInterleave 4 +#define OP_M 16 +#define OP_N 8 +#define OP_K 16 +#define INTRIN_M 16 +#define INTRIN_N 16 +#define INTRIN_K 16 +#define WARP_SIZE 32 +#define SMEM_PAD_A 0 +#define SMEM_PAD_B 0 +#define PACK_SIZE 8 +#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4) +#define L2_CACHEHINT(size) ".L2::" #size "B" +#else +#define L2_CACHEHINT(size) +#endif + +#define KERNEL_LAUNCH_CODE \ + int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \ + torch::Tensor _semaphores = torch::empty({num_mn_tiles}, options_int); \ + auto semaphores = reinterpret_cast(_semaphores.data_ptr()); \ + constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ + constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? 
(CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \ + constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(half); \ + if (kSmemByteSize >= 99 * 1024) \ + { \ + printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \ + return _out_feats; \ + } \ + int j_factors1 = num_out_channels / CTA_N / 1; \ + dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \ + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ + auto kernel_func = gemm_w4a16_T1; \ + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \ + kernel_func<<>>( \ + in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels); + +template +__inline__ __host__ __device__ int get_log_tile(int n) +{ + if (N >= 8 && n >= 6) + return 3; + else if (N >= 4 && n >= 3) + return 2; + else if (N >= 2 && n >= 2) + return 1; + else + return 0; +} + +__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile) +{ + return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1))); +} + +template +__device__ void sync_slice(int slice_id) +{ + if constexpr (SLICES == 1) + { + __syncthreads(); + } + else + { + constexpr int SLICE_GROUP = (SLICES + 7) / 8; + constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE; + const uint32_t barrier_id = slice_id / SLICE_GROUP + 1; + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads)); + } +} + +__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) +{ + uint32_t smem_int_ptr; + + asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_int_ptr) + : "l"(ptr)); + + return smem_int_ptr; +} + +__inline__ __device__ void ldmatrix_m8n8_x4_b16(half *shared_warp, int ax0_0, uint32_t addr) +{ + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr)); +} + +__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0_0, uint32_t addr) +{ + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];" + : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr)); +} + +__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) +{ + const int cp_size = 16; + asm volatile("{" + " .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;" + "}" ::"r"((int)mask), + "r"(smem_int_ptr), + "l"(src), + "n"(cp_size)); +} + +__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp) +{ + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" + : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3]) + : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned 
*)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])); +} + +template +__device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS; + constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; + constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used; + constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int ld_col = (threadIdx.x % threads_per_row); +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) + { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row); + int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE; + void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled); + uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE); + if constexpr (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows)); + } + else + { + if (local_mask & (ld_row + cta_offset_m < global_nrows)) + *(uint4 *)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS; + constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; + constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used; + constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) + { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + + int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row); + int ld_col = (threadIdx.x % threads_per_row); + int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7; + void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE)); + uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k); + if constexpr (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + } + else + { + if (local_mask) + *(uint4 *)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G; + constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1; + constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; + constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used; + constexpr int threads_per_row = CTA_N / PACK_SIZE; + constexpr int kSmemCol = CTA_N; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G; + + void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE); + void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE); + if (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z); + cp_async_cg_A(addr_z, src_ptr_z, local_mask); + } + else + { + if (local_mask) + { + *(uint4 *)dst_ptr = *src_ptr; + *(uint4 *)dst_ptr_z = *src_ptr_z; + } + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) +{ + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + + int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16); + int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k; + int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE; + void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled); + + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1) +{ + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; + int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8); + int c0 = ((threadIdx.x / 8) % 2) * 8; + int r = r0 / 4; + int c = (r0 % 4) * 16 + c0; + int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE; + + if constexpr (ldmatrix) + { +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k); + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } + } + +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + half scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + half zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + half2 scale2 = make_half2(scale, scale); + half2 zero2 = make_half2(zero, zero); + half2 loaded[4]; + + dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); +#pragma unroll + for (int i = 0; i < 4; i++) + { + loaded[i] = __hfma2(loaded[i], scale2, zero2); + } + *reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 
2)) = *reinterpret_cast(loaded); + } +} + +template +__global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K) +{ + constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; + constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K; + constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; + constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE; + constexpr int SLICES = CTA_K / WARP_K; + int num_blocks_n = (N + CTA_N - 1) / CTA_N; + int num_blocks_m = (M + CTA_M - 1) / CTA_M; + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n); + int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n); + const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N); + int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile); + int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile); + const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile); + blockIdx_m = block_idx_mapping.x; + blockIdx_n = block_idx_mapping.y; + + float C_warp[CTA_M * CTA_N / CTA_SIZE_MN]; + constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; + constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; + constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; + constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB; + constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES; + constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES; + constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1; + constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1; + constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load; + constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load; + extern __shared__ half mem_shared[]; + half *A_shared = mem_shared; + half *B_shared = mem_shared + kSmemSizeA; + half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB; + half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + float *C_shared = reinterpret_cast(mem_shared); + half A_shared_warp_[2][WARP_M * INTRIN_K / + WARP_SIZE]; + half B_shared_warp_[2][WARP_N * 32 / + WARP_SIZE]; + half B_shared_warp_tmp_[2][WARP_N * 16 / + WARP_SIZE]; + int cta_offset_m = blockIdx_m * CTA_M; + int cta_offset_n = blockIdx_n * CTA_N; + int cta_offset_k = blockIdx_z * (K / SPLITK); + int warp_mn = threadIdx.y % NUM_WARPS_MN; + int slice_id = threadIdx.y / NUM_WARPS_MN; + int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N; + int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M; + int warp_offset_k = slice_id * WARP_K; + + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++) + C_warp[i] = 0.0; + + int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK; + int k_0_0_ld = 0; + int k_0_0 = 0; + constexpr int prologue_stages = STAGES == 1 ? 
1 : STAGES - 1; +#pragma unroll + for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) + { + global_to_share_one_stage_A(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true); + global_to_share_one_stage_B(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true); + global_to_share_one_stage_scales( + scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N, + zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N, + N, cta_offset_m, cta_offset_n, cta_offset_k, + k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); + if constexpr (STAGES > 1) + __pipeline_commit(); + } + if constexpr (STAGES > 1) + __pipeline_wait_prior(STAGES - 2); + __syncthreads(); + + share_to_reg_one_stage_A(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0); + share_to_reg_one_stage_B(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0); + constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K; + + for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) + { + int ld_stage = k_0_0_ld % STAGES; + int compute_stage = k_0_0 % STAGES; + half *A_shared_this_compute_stage; + half *B_shared_this_compute_stage; + half *scales_shared_this_compute_stage; + half *zeros_shared_this_compute_stage; + +#pragma unroll + for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) + { + A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage; + B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage; + scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N; + zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N; + share_to_reg_one_stage_A(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); + if ((iter_k + 1) % kInterleave == 0) + { + if (compute_stage % 2 == 1) + { + share_to_reg_one_stage_B( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); + } + else + { + share_to_reg_one_stage_B( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); + } + } + else + { + if (compute_stage % 2 == 1) + { + share_to_reg_one_stage_B( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); + } + else + { + share_to_reg_one_stage_B( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS); + } + } + half *A_shared_warp = A_shared_warp_[iter_k % 2]; + half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; + + for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) + { + for (int j_0_4 = 0; j_0_4 < 
WARP_N / INTRIN_N; ++j_0_4) + { + mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + } + + if (iter_k < WARP_K / INTRIN_K - 1) + { + if constexpr (STAGES == 1) + __syncthreads(); + global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); + global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); + } + + if (iter_k == WARP_K / INTRIN_K - 2) + { + if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) + { + __syncthreads(); + } + global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); + global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); + global_to_share_one_stage_scales( + scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N, + zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N, + N, cta_offset_m, cta_offset_n, cta_offset_k, + k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); + if constexpr (STAGES > 1) + { + __pipeline_commit(); + __pipeline_wait_prior(STAGES - 2); + } + compute_stage = (k_0_0 + 1) % STAGES; + __syncthreads(); + } + } + } + __pipeline_commit(); + __pipeline_wait_prior(0); + __syncthreads(); + if constexpr (SLICES > 1) + { +#pragma unroll + for (int z = 0; z < SLICES; ++z) + { + if (slice_id == z) + { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + if (z > 0) + { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + } + C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]; + }; + } + } + } + __syncthreads(); + } + if (slice_id == 0) + { +#pragma unroll + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { +#pragma unroll + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { +#pragma unroll + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id) + { + C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2]; + }; + } + } + } + } + + if (slice_id == 0) + { + Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x); + + if constexpr (SPLITK > 1) + { + semaphore.fetch(); + } + + if (blockIdx_z != 0) + { + 
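+            // Split-K serialization: CTAs with blockIdx_z > 0 wait until the previous split has published its
+            // partial sums, then accumulate their own partials into C with a read-modify-write. The semaphore is
+            // released further down with blockIdx_z + 1, or with 0 by the last split to reset the lock.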
semaphore.wait(blockIdx_z); + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + + if (write_row < M) + { + half2 *existing_psum_ptr = reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2); + + *existing_psum_ptr = __hadd2(*existing_psum_ptr, + __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id))); + } + }; + } + } + } + else + { + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + if (write_row < M) + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id)); + } + }; + } + } + } + + if constexpr (SPLITK > 1) + { + + int lock = 0; + if (SPLITK == blockIdx_z + 1) + { + + lock = 0; + } + else + { + lock = blockIdx_z + 1; + } + semaphore.release(lock); + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS; + constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; + constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used; + constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int ld_col = (threadIdx.x % threads_per_row); +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) + { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row); + int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE; + void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled); + uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE); + if constexpr (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows)); + } + else + { + if (local_mask & (ld_row + cta_offset_m < global_nrows)) + *(uint4 *)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS; + constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE; + constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used; + constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS; + constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K; + constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K; + constexpr int threads_per_row = CTA_K / PACK_SIZE; + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); +#pragma unroll + for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter) + { + int global_iter = shared_iter_k * partial_global_iters + _global_iter; + + int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row); + int ld_col = (threadIdx.x % threads_per_row); + int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7; + void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE)); + uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE); + if constexpr (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + } + else + { + if (local_mask) + *(uint4 *)dst_ptr = *src_ptr; + } + } +} + +template +__device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask) +{ + constexpr int threads_needed = CTA_N / PACK_SIZE / 1; + constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE; + constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used; + constexpr int threads_per_row = CTA_N / PACK_SIZE; + constexpr int kSmemCol = CTA_N; + bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used); + int g_idx = global_iter_k * CTA_K / G; + + void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE); + void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE); + uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE); + if (STAGES > 1) + { + uint32_t addr = cast_smem_ptr_to_uint(dst_ptr); + cp_async_cg_A(addr, src_ptr, local_mask); + uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z); + cp_async_cg_A(addr_z, src_ptr_z, local_mask); + } + else + { + if (local_mask) + { + *(uint4 *)dst_ptr = *src_ptr; + *(uint4 *)dst_ptr_z = *src_ptr_z; + } + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int warp_offset_m, int warp_offset_n, int k_0_1) +{ + constexpr int kSmemCol = CTA_K + SMEM_PAD_A; + + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + + int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16); + int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8; + int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE; + void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled); + + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } +} + +template +__device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int 
warp_offset_m, int warp_offset_n, int k_0_1) +{ + constexpr int kSmemCol = CTA_K + SMEM_PAD_B; + int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8); + int c0 = ((threadIdx.x / 8) % 2) * 8; + int r = r0 / 4; + int c = (r0 % 4) * 16 + c0; + int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE; + + if constexpr (ldmatrix) + { +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled); + uint32_t addr = cast_smem_ptr_to_uint(addr_ptr); + ldmatrix_m8n8_x4_b16(dst, shared_iter, addr); + } + } + +#pragma unroll + for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter) + { + half scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + half zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4]; + half2 scale2 = make_half2(scale, scale); + half2 zero2 = make_half2(zero, zero); + half2 loaded[4]; + dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded)); +#pragma unroll + for (int i = 0; i < 4; i++) + { + loaded[i] = __hfma2(loaded[i], scale2, zero2); + } + *reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast(loaded); + } +} + +template +__global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int M, int N, int K) +{ + constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N; + constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE; + int num_blocks_n = (N + CTA_N - 1) / CTA_N; + int num_blocks_m = (M + CTA_M - 1) / CTA_M; + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n); + int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n); + const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N); + int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile); + int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile); + const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile); + blockIdx_m = block_idx_mapping.x; + blockIdx_n = block_idx_mapping.y; + + float C_warp[CTA_M * CTA_N / CTA_SIZE]; + constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A; + constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B; + constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA; + constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB; + constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES; + constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES; + constexpr int kSmemSizeScales = CTA_N * STAGES / 2; + constexpr int kSmemSizeZeros = CTA_N * STAGES / 2; + constexpr int scales_load_interval = G / CTA_K; + extern __shared__ half mem_shared[]; + half *A_shared = mem_shared; + half *B_shared = mem_shared + kSmemSizeA; + half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB; + half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales; + half A_shared_warp_[2][WARP_M * INTRIN_K / + WARP_SIZE]; + half B_shared_warp_[2][WARP_N * 32 / + WARP_SIZE]; + half B_shared_warp_tmp_[2][WARP_N * 16 / + WARP_SIZE]; + int cta_offset_m = blockIdx_m * CTA_M; + int cta_offset_n = blockIdx_n * CTA_N; + int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M; + int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N; + + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++) + C_warp[i] = 
0.0; + + int gemm_iters = (K + CTA_K - 1) / CTA_K; + int k_0_0_ld = 0; + int k_0_0 = 0; + constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1; +#pragma unroll + for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) + { + global_to_share_one_stage_A_T2(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); + global_to_share_one_stage_B_T2(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); + global_to_share_one_stage_scales_T2( + scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N, + zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N, + N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); + if constexpr (STAGES > 1) + __pipeline_commit(); + } + if constexpr (STAGES > 1) + __pipeline_wait_prior(STAGES - 2); + __syncthreads(); + + share_to_reg_one_stage_A_T2(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0); + share_to_reg_one_stage_B_T2(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0); + constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K; + + for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) + { + int ld_stage = k_0_0_ld % STAGES; + int compute_stage = k_0_0 % STAGES; + half *A_shared_this_compute_stage; + half *B_shared_this_compute_stage; + half *scales_shared_this_compute_stage; + half *zeros_shared_this_compute_stage; + + for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) + { + A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage; + B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage; + scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N; + zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N; + share_to_reg_one_stage_A_T2(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); + if ((iter_k + 1) % kInterleave == 0) + { + if (compute_stage % 2 == 1) + { + share_to_reg_one_stage_B_T2( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); + } + else + { + share_to_reg_one_stage_B_T2( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); + } + } + else + { + if (compute_stage % 2 == 1) + { + share_to_reg_one_stage_B_T2( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); + } + else + { + share_to_reg_one_stage_B_T2( + B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, + B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], + warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); + } + } + __syncthreads(); + half *A_shared_warp = A_shared_warp_[iter_k % 2]; + half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; + for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) + { + for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) + { + mma_m16n8k16(C_warp + 
i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); + mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); + } + } + + if (iter_k < WARP_K / INTRIN_K - 1) + { + if constexpr (STAGES == 1) + __syncthreads(); + global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); + global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); + } + + if (iter_k == WARP_K / INTRIN_K - 2) + { + if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2) + { + __syncthreads(); + } + global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); + global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); + global_to_share_one_stage_scales_T2( + scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N, + zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N, + N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); + if constexpr (STAGES > 1) + { + __pipeline_commit(); + __pipeline_wait_prior(STAGES - 2); + } + compute_stage = (k_0_0 + 1) % STAGES; + __syncthreads(); + } + } + } + for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) + { + for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) + { + for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) + { + int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); + if (write_row < M) + { + *reinterpret_cast( + C + write_row * N + + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = + __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + + ax1_0_1 * 8 + local_id)); + } + }; + } + } +} + +torch::Tensor gemm_forward_cuda_prefill( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scales, + torch::Tensor _zeros) +{ + std::vector output_shape = _in_feats.sizes().vec(); + output_shape.back() = _kernel.size(0) * kInterleave; + int num_in_feats = _in_feats.numel() / _in_feats.size(-1); + int num_in_channels = _in_feats.size(-1); + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto options = + torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty(output_shape, options); + int num_out_feats = _out_feats.numel() / _out_feats.size(-1); + int num_out_channels = _out_feats.size(-1); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + if (num_out_feats <= 32) + { + constexpr int G = 128; + constexpr int CTA_M = 16; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 16; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 2; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } + else if (num_out_feats <= 
64) + { + + constexpr int G = 128; + constexpr int CTA_M = 16; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 16; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 3; + KERNEL_LAUNCH_CODE + } + else if (num_out_feats <= 128) + { + constexpr int G = 128; + constexpr int CTA_M = 32; + constexpr int CTA_N = 128; + constexpr int CTA_K = 128; + constexpr int WARP_M = 32; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } + else if (num_out_feats <= 192) + { + constexpr int G = 128; + constexpr int CTA_M = 64; + constexpr int CTA_N = 128; + constexpr int CTA_K = 64; + constexpr int WARP_M = 64; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int SPLITK = 1; + constexpr int STAGES = 4; + KERNEL_LAUNCH_CODE + } + else + { + constexpr int G = 128; + constexpr int CTA_M = 64; + constexpr int CTA_N = 128; + constexpr int CTA_K = 64; + constexpr int WARP_M = 64; + constexpr int WARP_N = 32; + constexpr int WARP_K = 64; + constexpr int STAGES = 4; + + constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N); + constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(half); + if (kSmemByteSize >= 99 * 1024) + { + printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); + return _out_feats; + } + int j_factors1 = num_out_channels / CTA_N / 1; + dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1); + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); + auto kernel_func = gemm_w4a16_T2; + cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); + kernel_func<<>>( + in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels); + } + + return _out_feats; +} \ No newline at end of file diff --git a/awq_ext/quantization_new/gemm/gemm_cuda.h b/awq_ext/quantization_new/gemm/gemm_cuda.h new file mode 100644 index 0000000..60c9ece --- /dev/null +++ b/awq_ext/quantization_new/gemm/gemm_cuda.h @@ -0,0 +1,3 @@ +#include + +torch::Tensor gemm_forward_cuda_prefill(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros); diff --git a/awq_ext/quantization_new/gemm/semaphore.h b/awq_ext/quantization_new/gemm/semaphore.h new file mode 100644 index 0000000..acc636f --- /dev/null +++ b/awq_ext/quantization_new/gemm/semaphore.h @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CTA-wide semaphore for inter-CTA synchronization. +class Semaphore +{ +public: + int *lock; + bool wait_thread; + int state; + +public: + /// Implements a semaphore to wait for a flag to reach a given value + __host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_), + wait_thread(thread_id < 0 || thread_id == 0), + state(-1) + { + } + + /// Permit fetching the synchronization mechanism early + __device__ void fetch() + { + if (wait_thread) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); +#else + asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); +#endif + } + } + + /// Gets the internal state + __device__ int get_state() const + { + return state; + } + + /// Waits until the semaphore is equal to the given value + __device__ void wait(int status = 0) + { + while (__syncthreads_and(state != status)) + { + fetch(); + } + + __syncthreads(); + } + + /// Updates the lock with the given result + __device__ void release(int status = 0) + { + __syncthreads(); + + if (wait_thread) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/awq_ext/quantization_new/gemv/gemv_cuda.cu b/awq_ext/quantization_new/gemv/gemv_cuda.cu new file mode 100644 index 0000000..78d12b4 --- /dev/null +++ b/awq_ext/quantization_new/gemv/gemv_cuda.cu @@ -0,0 +1,329 @@ +/* + * Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv) + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} +} +*/ + +#include +#include +#include +#include "gemv_cuda.h" +#include "../dequantize.cuh" +#define PACK_FACTOR 8 +#define WARP_SIZE 32 +#define MEM_ACCESS_SIZE 128 + + +static inline __device__ float to_float(half src) +{ + return __half2float(src); +} + +static inline __device__ float to_float(float src) +{ + return src; +} + +static inline __device__ half to_half(float src) +{ + return __float2half(src); +} + +static inline __device__ half to_half(half src) +{ + return src; +} + +// Reduce sum within the warp using the tree reduction algorithm. +template +__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4]) +{ + // kInterleave = 4 + float fpsum[Num]; + #pragma unroll + for (int i = 0; i < Num; ++i) + { + fpsum[i] = to_float(psum[i]); + } + + #pragma unroll + for (int i = 0; i < Num; ++i) + { + // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4) + fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16); + fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8); + fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1); + } + __syncthreads(); + int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; + if (lane == 0 || lane == 2 || lane == 4 || lane == 6) + { + #pragma unroll + for (int i = 0; i < Num; ++i) + { + out_smem[warp][i * 4 + lane / 2] = fpsum[i]; + } + } + __syncthreads(); +}; + +__device__ __forceinline__ int make_divisible(int c, int divisor){ + return (c + divisor - 1) / divisor; +} + +template +__global__ void gemv_kernel( + const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs, + const int IC, const int OC) +{ + const int kStride = 64; + const int kElemsPerThread = MEM_ACCESS_SIZE / 4; + const int kThreadsNumPerTile = kStride / kElemsPerThread; + // assert(MEM_ACCESS_SIZE == 128); + + static constexpr int kShuffleSize = 32; + static constexpr int kShuffleBasicTile = 2; + static constexpr int kShuffleContinous = 4; + static constexpr int kShuffleStrided = 4; + + constexpr int Num = NPerBlock * Batch; + constexpr int kInterleave = 4; + + half local_inputs[kElemsPerThread]; + uint32_t local_qweights[MEM_ACCESS_SIZE / 32]; + half half_weight_buffer[kElemsPerThread]; + half dequantized_weight[kElemsPerThread * NPerBlock]; + half local_scale[NPerBlock]; + half local_scaled_zeros[NPerBlock]; + + half psum[Num]; + for (int i = 0; i < Num; ++i) + psum[i] = to_half(0.f); + + extern __shared__ uint8_t shmem[]; + float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem); + + const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave; + const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave; + const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride + + (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread; + const int group_offset = act_k_offset / GroupSize; + // TODO: use make_divisible + 
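+    // Pointer setup: each block owns NPerBlock * kInterleave output channels starting at blk_row_offset.
+    // The quantized weights pack PACK_FACTOR (= 8) int4 values per uint32_t, hence the division below, and
+    // scales/zeros store one row of OC entries per group of GroupSize input channels.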
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR; + const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC; + const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC; + const half* inputs_ptr = inputs + act_k_offset; + + const int act_forward_step = BlockSize * kElemsPerThread / kInterleave; + const int scale_forward_step = act_forward_step / GroupSize * OC; + + // Main loop iteration, each block completes the outputs for several OCs + for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) + { + // Load qweight, scales and scaled_zeros + #pragma unroll + for (int idx = 0; idx < NPerBlock; ++idx) + { + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit) + *((float4*)(local_qweights)) = + *((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR)); + local_scale[idx] = *(scale_ptr + idx * kInterleave); + local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave); + + // Map int4 qweight to fp format + #pragma unroll + for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) + { + // Converts 32 bits (8 x int4) to 8 fp16 + dequantize_s4_to_fp16x2(*reinterpret_cast(local_qweights + i), reinterpret_cast(half_weight_buffer + i * PACK_FACTOR)); + } + + // Dequantize (apply s/z) and shuffle elements to match the weight packing format + #pragma unroll + for (int i = 0; i < kShuffleContinous; ++i) + { + #pragma unroll + for (int j = 0; j < kShuffleStrided; ++j) + { + half2 w = + *reinterpret_cast( + half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile + ); + w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx])); + dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) + * NPerBlock + idx] + = w.x; + dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) + * NPerBlock + idx] + = w.y; + } + } + } + #pragma unroll + for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) + { + const half* local_inputs_ptr = inputs_ptr + batch_idx * IC; + #pragma unroll + for (int idx = 0; idx < kElemsPerThread / 8; ++idx) + { + // load activation, 8 halves (128 bits) / step. + *((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8)); + } + // Perform the MACs + #pragma unroll + for (int x = 0; x < NPerBlock / 2; ++x) + { + #pragma unroll + for (int y = 0; y < kElemsPerThread; ++y) + { + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) + = __hfma2(*reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2), + __half2half2(local_inputs[y]), + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2)); + } + } + } + inputs_ptr += act_forward_step; + scale_ptr += scale_forward_step; + zeros_ptr += scale_forward_step; + } + + warp_reduce(psum, out_smem); + + // Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num + for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) + { + int batch_idx = i / (NPerBlock * kInterleave); + int oc_idx = i % (NPerBlock * kInterleave); + float acc = 0.f; + for (int j = 0; j < BlockSize / WARP_SIZE; ++j) + { + acc += out_smem[j][i]; + } + outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc); + } +} + +/* +Computes GEMV (PyTorch interface). 
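+
+Each thread block produces NPerBlock * kInterleave output channels: the int4 weights are dequantized
+in registers with per-group scales/zeros, multiplied against the activations in fp16, and the per-thread
+partial sums are tree-reduced across each warp into shared memory before the final write-back.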
diff --git a/awq_ext/quantization_new/gemv/gemv_cuda.h b/awq_ext/quantization_new/gemv/gemv_cuda.h
new file mode 100644
index 0000000..9dd8b32
--- /dev/null
+++ b/awq_ext/quantization_new/gemv/gemv_cuda.h
@@ -0,0 +1,12 @@
+#pragma once
+#include <torch/extension.h>
+
+torch::Tensor gemv_forward_cuda_decode(
+    torch::Tensor _in_feats,
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int m,
+    int n,
+    int k,
+    int group_size);
diff --git a/scripts/download_wheels.sh b/scripts/download_wheels.sh
index e736a44..fc00c2a 100644
--- a/scripts/download_wheels.sh
+++ b/scripts/download_wheels.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # Set variables
-AWQ_KERNELS_VERSION="0.0.5"
+AWQ_KERNELS_VERSION="0.0.6"
 RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ_kernels/releases/tags/v${AWQ_KERNELS_VERSION}"
 
 # Create a directory to download the wheels
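The setup.py hunks below turn the target compute capabilities into a parameter of get_compute_capabilities so that the new awq_v2_ext extension can be limited to Ampere-and-newer GPUs. As a rough sketch of what that call produces (reusing the same loop shown in the hunk; the capability set is the one passed for awq_v2_ext):

    # Sketch of the -gencode expansion for get_compute_capabilities({80, 86, 89, 90})
    capability_flags = []
    for cap in {80, 86, 89, 90}:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
    # e.g. ["-gencode", "arch=compute_80,code=sm_80", "-gencode", "arch=compute_86,code=sm_86", ...]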
"g++" os.environ["CXX"] = "g++" -AUTOAWQ_KERNELS_VERSION = "0.0.5" +AUTOAWQ_KERNELS_VERSION = "0.0.6" PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1" CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda ROCM_VERSION = os.environ.get("ROCM_VERSION", None) or torch.version.hip @@ -89,7 +89,9 @@ def get_generator_flag(): return generator_flag -def get_compute_capabilities(): +def get_compute_capabilities( + compute_capabilities={75, 80, 86, 89, 90} +): capability_flags = [] if CUDA_VERSION: @@ -103,7 +105,6 @@ def get_compute_capabilities(): ) # Figure out compute capability - compute_capabilities = {75, 80, 86, 89, 90} for cap in compute_capabilities: capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"] @@ -180,6 +181,22 @@ def get_extra_link_args(): ) ) + # only compatible with ampere + arch_flags = get_compute_capabilities({80, 86, 89, 90}) + extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags) + + extensions.append( + CUDAExtension( + "awq_v2_ext", + [ + "awq_ext/pybind_awq_v2.cpp", + "awq_ext/quantization_new/gemv/gemv_cuda.cu", + "awq_ext/quantization_new/gemm/gemm_cuda.cu", + ], + extra_compile_args=extra_compile_args_v2, + ) + ) + extensions.append( CUDAExtension( "exl_ext",