Enable certain CUDA kernels to accept a specified CUDA stream (#1330)
* Done

* fix format

* fix format

* fix format

* fix format

* Address format error and fix default arg bug

* Refine stream argument passing mechanism

* Fix bug

* Delete unused code
jeejeelee authored Aug 22, 2024
1 parent 6ae9859 commit a685654
Showing 4 changed files with 77 additions and 58 deletions.
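
At a high level, the Python wrappers now look up the CUDA stream that is current for the input tensor's device and hand it to the native library, and the C++/CUDA launchers forward it as the fourth kernel launch parameter instead of using the default stream. A minimal sketch of the Python side, assuming a CUDA device is available; the lib.cdequantize call in the trailing comment is only an illustrative call site, and the real bindings appear in the diffs below:

import torch

def get_tensor_stream(tensor: torch.Tensor) -> torch.cuda.Stream:
    # Stream currently active on the tensor's device; kernels launched through
    # the ctypes bindings are enqueued on this stream rather than the default one.
    return torch.cuda.current_stream(tensor.device)

A = torch.zeros(1024, dtype=torch.uint8, device="cuda")
stream = get_tensor_stream(A)
# Per the comment in the diff, torch.cuda.Stream exposes _as_parameter_, so the
# object can be passed straight to a ctypes foreign function whose final
# argument is a cudaStream_t, e.g.:
#   lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)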
25 changes: 22 additions & 3 deletions bitsandbytes/functional.py
@@ -439,6 +439,11 @@ def is_on_gpu(tensors):
return on_gpu


def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
stream = torch.cuda.current_stream(tensor.device)
return stream


def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
"""
Get the ctypes pointer from a PyTorch Tensor.
@@ -973,6 +978,7 @@ def dequantize_blockwise(
f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
)
is_on_gpu([A, absmax, out])
stream = get_tensor_stream(A)
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(
get_ptr(quant_state.code),
@@ -981,6 +987,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
stream, # Uses the _as_parameter_ attribute of torch.cuda.Stream; the same applies to the calls below
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(
@@ -990,6 +997,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
stream,
)
elif out.dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(
@@ -999,6 +1007,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
stream,
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
@@ -1176,7 +1185,6 @@ def quantize_4bit(

prev_device = pre_call(A.device)
is_on_gpu([A, out, absmax])

if A.dtype == torch.float32:
if quant_type == "fp4":
lib.cquantize_blockwise_fp32_fp4(
@@ -1356,6 +1364,7 @@ def dequantize_4bit(

device = pre_call(A.device)
is_on_gpu([A, absmax, out])
stream = get_tensor_stream(A)
if out.dtype == torch.float32:
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(
@@ -1365,6 +1374,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else:
lib.cdequantize_blockwise_fp32_nf4(
@@ -1374,6 +1384,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
elif out.dtype == torch.float16:
if quant_state.quant_type == "fp4":
@@ -1384,6 +1395,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else:
lib.cdequantize_blockwise_fp16_nf4(
@@ -1393,6 +1405,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
elif out.dtype == torch.bfloat16:
if quant_state.quant_type == "fp4":
@@ -1403,6 +1416,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else:
lib.cdequantize_blockwise_bf16_nf4(
@@ -1412,6 +1426,7 @@
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
@@ -1518,7 +1533,8 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
stream = get_tensor_stream(A)
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
post_call(prev_device)
return out

@@ -2002,7 +2018,7 @@ def gemv_4bit(
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)

stream = get_tensor_stream(A)
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(
@@ -2018,6 +2034,7 @@
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemm_4bit_inference_naive_bf16(
@@ -2033,6 +2050,7 @@
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemm_4bit_inference_naive_fp32(
@@ -2048,6 +2066,7 @@
ldb,
ldc,
ct.c_int32(state.blocksize),
stream,
)
else:
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
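
Because the wrappers above pick up whatever stream is current for the input tensor, a caller can route the dequantization and GEMV kernels onto a side stream without any new arguments. A hedged usage sketch; the tensor sizes and the synchronization policy are illustrative and not part of this commit:

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed, quant_state = F.quantize_4bit(A)                 # quantize to 4-bit (FP4 by default)

side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())     # packed/absmax were produced on the default stream
with torch.cuda.stream(side_stream):
    # dequantize_4bit now reads torch.cuda.current_stream(A.device) internally,
    # so its kernel is enqueued on side_stream rather than the default stream.
    restored = F.dequantize_4bit(packed, quant_state)
torch.cuda.current_stream().wait_stream(side_stream)     # order later default-stream work after it
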
41 changes: 20 additions & 21 deletions csrc/ops.cu
@@ -44,11 +44,11 @@ void quantize(float *code, float *A, unsigned char *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void dequantize(float *code, unsigned char *A, float *out, int n)
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream)
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
kDequantize<<<num_blocks, 1024, 0, stream>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

@@ -76,16 +76,16 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream)
{
// printf("stream==%d\n",stream);
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;

if(DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n);
else
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);

CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -724,12 +724,11 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream)
{

int num_blocks = (m+3)/4;

kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

@@ -753,9 +752,9 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);

//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
@@ -795,15 +794,15 @@ template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __n
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);

template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);

#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
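
On the C++ side the new parameter simply occupies the fourth slot of the kernel launch configuration (grid, block, dynamic shared memory, stream). A minimal self-contained sketch of that convention; the kernel and wrapper below are illustrative and not code from this repository:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void kScale(float *x, float s, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;                                  // one element per thread
}

void scale(float *x, float s, int n, cudaStream_t stream)
{
    int threads = 256;
    int blocks = (n + threads - 1) / threads;              // ceil(n / threads), as in dequantize()
    // Fourth launch parameter selects the stream; 0 bytes of dynamic shared memory.
    kScale<<<blocks, threads, 0, stream>>>(x, s, n);
    cudaError_t err = cudaPeekAtLastError();               // launch check, mirroring CUDA_CHECK_RETURN
    if (err != cudaSuccess) fprintf(stderr, "%s\n", cudaGetErrorString(err));
}
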
7 changes: 4 additions & 3 deletions csrc/ops.cuh
@@ -7,6 +7,7 @@
#ifndef ops_H
#define ops_H

#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <assert.h>
@@ -142,9 +143,9 @@ class ContextCusparse
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);

void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream);

template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
@@ -195,7 +196,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows

template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);

template <typename T, int FUNC> void func(T *A, T *B, T value, long n);

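
The header above only declares the stream-aware C++ templates; the ctypes-facing entry points (in the fourth changed file, which is not rendered on this page) still have to carry the stream across the C boundary. A hypothetical shim of that shape, with the name and instantiation chosen to match the calls in functional.py and the declarations in ops.cuh rather than copied from the actual diff:

#include <cuda_fp16.h>
#include "ops.cuh"

// Hypothetical C-callable wrapper; the real one lives in the interface file of
// this commit, which is not shown above.
extern "C" void cdequantize_blockwise_fp16(
    float *code, unsigned char *A, float *absmax, half *out,
    int blocksize, const int n, cudaStream_t stream)
{
    // Forward to the template instantiation declared in ops.cuh and defined in ops.cu.
    dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}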
