Merge pull request #1 from IlyasMoutawwakil/exllama-kernels
Exllama kernels
casper-hansen authored Jan 21, 2024
2 parents a22d67a + cb74e4e commit fc700a8
Showing 35 changed files with 4,469 additions and 19 deletions.
58 changes: 58 additions & 0 deletions awq_ext/exllama/cu_compat.cuh
@@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    // CAS loop on the aligned 32-bit word that contains the target half.
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        // Pick out the high or low 16 bits, depending on which half is addressed.
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        // Splice the updated half back into the word; retry if another thread won.
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
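
The guarded overloads above only take effect on devices without native half-precision atomics (CC < 7.0 for half, CC < 6.0 for half2, and ROCm). A minimal usage sketch, assuming an illustrative kernel that is not part of this commit:

#include <cuda_fp16.h>
#include "cu_compat.cuh"

// Hypothetical example: many threads accumulate into a small half buffer.
// On older hardware the CAS-based overloads above are selected; on newer
// NVIDIA parts the hardware intrinsics take over.
__global__ void accumulate_half(half* out, const half* in, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(&out[i % 32], in[i]);
}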
75 changes: 75 additions & 0 deletions awq_ext/exllama/cuda_buffers.cu
@@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
) :
    device(_device),
    temp_state_size(_temp_state_size),
    temp_state(_temp_state),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
    return g_buffers[device_index];
}

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
)
{
    CudaBuffers* buffers = new CudaBuffers
    (
        _device,
        _temp_state_size,
        _temp_state,
        _temp_dq
    );

    g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
55 changes: 55 additions & 0 deletions awq_ext/exllama/cuda_buffers.cuh
@@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
    int device;

    half* temp_state;   // [max_hidden_rows * intermediate_size]
    int temp_state_size;
    half* temp_dq;      // size of largest quant tensor * 8

    cudaStream_t alt_stream_1;
    cudaStream_t alt_stream_2;
    cudaStream_t alt_stream_3;
    cudaEvent_t alt_stream_1_done;
    cudaEvent_t alt_stream_2_done;
    cudaEvent_t alt_stream_3_done;

    CudaBuffers
    (
        int _device,
        int _temp_state_size,
        half* _temp_state,
        half* _temp_dq
    );
    ~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
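
The scratch buffers are created once per device and looked up by kernels later. A hedged sketch of the host-side call pattern, with illustrative names, sizes, and raw cudaMalloc allocations (in the actual extension the pointers are handed in from the Python side):

#include "cuda_buffers.cuh"

// Hypothetical driver, not part of this commit.
void example_setup(int rows, int cols, size_t dq_elems)
{
    half *temp_state, *temp_dq;
    cudaSetDevice(0);
    cudaMalloc(&temp_state, (size_t)rows * cols * sizeof(half));
    cudaMalloc(&temp_dq, dq_elems * sizeof(half));

    prepare_buffers_cuda(0, rows * cols, temp_state, temp_dq);

    CudaBuffers* buffers = get_buffers(0);  // used inside kernel launchers
    (void)buffers;

    cleanup_buffers_cuda();  // destroys streams/events; the caller still owns
    cudaFree(temp_state);    // the device memory itself
    cudaFree(temp_dq);
}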
63 changes: 63 additions & 0 deletions awq_ext/exllama/cuda_func/column_remap.cu
@@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "column_remap.cuh"
#include "../util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
    if (x_column >= x_width) return;
    //if (x_row >= x_height) return;

    int x_stride = x_width;
    int x_idx = x_row * x_stride + x_column;

    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
    int x_idx_end = x_row_end * x_stride + x_column;

    int s_column = x_map[x_column];
    int s_idx = x_row * x_stride + s_column;

    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
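
The comment above the launcher pins down the semantics: column j of the remapped matrix is column x_map[j] of the input, applied to every row, so that seq_x @ seq_w equals x @ w once the weight rows are permuted the same way. A plain CPU reference of that permutation, for clarity only (float instead of half, not part of this commit):

#include <cstdint>

// Host-side reference model of column_remap_cuda.
void column_remap_cpu(const float* x, float* x_new,
                      int height, int width, const uint32_t* x_map)
{
    for (int r = 0; r < height; r++)
        for (int c = 0; c < width; c++)
            x_new[r * width + c] = x[r * width + x_map[c]];
}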
19 changes: 19 additions & 0 deletions awq_ext/exllama/cuda_func/column_remap.cuh
@@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
);

#endif