From f9414e09da0512e3ae7e81d130b94627211c5492 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 24 Feb 2024 16:42:30 +0100 Subject: [PATCH] Fix casting --- awq_ext/quantization_new/gemv/gemv_cuda.cu | 27 +++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/awq_ext/quantization_new/gemv/gemv_cuda.cu b/awq_ext/quantization_new/gemv/gemv_cuda.cu index 5c49626..78d12b4 100644 --- a/awq_ext/quantization_new/gemv/gemv_cuda.cu +++ b/awq_ext/quantization_new/gemv/gemv_cuda.cu @@ -33,6 +33,27 @@ #define WARP_SIZE 32 #define MEM_ACCESS_SIZE 128 + +static inline __device__ float to_float(half src) +{ + return __half2float(src); +} + +static inline __device__ float to_float(float src) +{ + return src; +} + +static inline __device__ half to_half(float src) +{ + return __float2half(src); +} + +static inline __device__ half to_half(half src) +{ + return src; +} + // Reduce sum within the warp using the tree reduction algorithm. template __device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4]) @@ -42,7 +63,7 @@ __device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem) #pragma unroll for (int i = 0; i < Num; ++i) { - fpsum[i] = static_cast(psum[i]); + fpsum[i] = to_float(psum[i]); } #pragma unroll @@ -97,7 +118,7 @@ __global__ void gemv_kernel( half psum[Num]; for (int i = 0; i < Num; ++i) - psum[i] = static_cast(0.f); + psum[i] = to_half(0.f); extern __shared__ uint8_t shmem[]; float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem); @@ -199,7 +220,7 @@ __global__ void gemv_kernel( { acc += out_smem[j][i]; } - outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast(acc); + outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc); } }