Updated to newer base to support QLoRA #2

Open · wants to merge 65 commits into base branch cmake_windows

Commits (65)
3ac5840
Added fp4 quant/dequant and dequant optimizations.
TimDettmers Feb 4, 2023
160a835
Forward matmul_fp4 tests pass.
TimDettmers Feb 5, 2023
13c0a4d
Backward matmul_fp4 passes.
TimDettmers Feb 5, 2023
cfe4705
Added matmul_fp4 to the benchmark.
TimDettmers Feb 5, 2023
c361f84
Fixed matmul_fp4 transpose.
TimDettmers Feb 5, 2023
c0c352b
Added bias test for LinearFP4 and basic test.
TimDettmers Feb 5, 2023
7f0773a
Added backprop test for Linear8bitLt and LinearFP4.
TimDettmers Feb 5, 2023
c93a90d
Fixed FP4 import and data type conversion in backward.
TimDettmers Feb 14, 2023
9851a10
Added cast to fp4 layer for speed.
TimDettmers Feb 24, 2023
6c31a5f
t5 model fix
artidoro Feb 27, 2023
6981052
Some small changes.
TimDettmers Mar 27, 2023
8645d1f
Added normal quant.
TimDettmers Mar 30, 2023
c4cfe4f
Added bf16 Adam.
TimDettmers Apr 1, 2023
51a21df
Added 8-bit compression to quantization statistics.
TimDettmers Apr 1, 2023
2dd5d69
Generalized FP4 data type.
TimDettmers Apr 2, 2023
0d332a6
Added normal with extra value.
TimDettmers Apr 2, 2023
4ad999d
Added quantization tree generation.
TimDettmers Apr 2, 2023
64cc059
First draft of NF4.
TimDettmers Apr 2, 2023
4ea489d
Refactor FP4 into 4Bit and integrate NF4 data type.
TimDettmers Apr 3, 2023
1ccb7bd
Fixed ParamsIn4 init; fixed PyTorch 2.0 test failure.
TimDettmers Apr 4, 2023
e9fa03b
Some fixes for loading PEFT modules with Params4bit.
TimDettmers Apr 7, 2023
b8ea2b4
Fixed bias conversion in Linear4bit
TimDettmers Apr 12, 2023
7dc198f
Added 32-bit optimizer for bfloat16 gradients.
TimDettmers Apr 18, 2023
0f9d302
Added nested quantization for blockwise quantization.
TimDettmers Apr 19, 2023
6bfd7a4
Initial template.
TimDettmers Apr 25, 2023
6e2544d
Added cutlass example.
TimDettmers Apr 25, 2023
84964db
CUTLASS compiles.
TimDettmers Apr 26, 2023
0afc8e9
Best attempt at cutlass3.
TimDettmers Apr 27, 2023
d1c4c20
Added non-cutlass template.
TimDettmers Apr 27, 2023
9cab14a
Added pipeline draft.
TimDettmers Apr 27, 2023
c1bfb21
First baseline kernel.
TimDettmers Apr 29, 2023
3aef783
Added template refactor.
TimDettmers Apr 29, 2023
f6df4ae
Added fp16 and thread/item template.
TimDettmers Apr 29, 2023
f3e97cc
New implementation for batch size 1.
TimDettmers Apr 29, 2023
cad8399
Added bit template.
TimDettmers Apr 29, 2023
21723f7
4-bit draft.
TimDettmers Apr 30, 2023
ad07d25
Slow tensor core solution.
TimDettmers May 1, 2023
604bb3f
Slow non-vector 530.
TimDettmers May 1, 2023
c35ed09
Double frag 440.
TimDettmers May 1, 2023
e01d4e0
Fixed bank conflicts in non-vector load 422.
TimDettmers May 1, 2023
30d03e0
64 threads, high smem, 434.
TimDettmers May 1, 2023
cabcd9b
Halved shared memory 466.
TimDettmers May 1, 2023
7cc8ff4
Warp specialization 362.
TimDettmers May 1, 2023
3d4a2ea
16x16 240.
TimDettmers May 1, 2023
7bfa09d
8x32 240 6 warps.
TimDettmers May 1, 2023
f9bfea8
Baseline for debugging.
TimDettmers May 2, 2023
9192c9d
Tighter and scaled error analysis.
TimDettmers May 2, 2023
9aa232c
Initial.
TimDettmers May 2, 2023
394749d
Correct implementation 240.
TimDettmers May 2, 2023
4decb3c
Removed unnecessary sync.
TimDettmers May 2, 2023
89cccd8
A tile multi-tiling.
TimDettmers May 2, 2023
77f15fd
Shared memory efficient 240.
TimDettmers May 2, 2023
869b7e8
Warp multi-specialization 240.
TimDettmers May 2, 2023
264a948
4-bit draft; 128 vector load 240.
TimDettmers May 2, 2023
ec38ba9
Added paging.
TimDettmers May 6, 2023
44d68ff
Added paged optimizers.
TimDettmers May 6, 2023
41a9c70
Changed prefetching.
TimDettmers May 7, 2023
f64cfe6
Fixed prefetch bug for non-paged tensors; added benchmark.
TimDettmers May 7, 2023
675baa7
Merge remote-tracking branch 'origin/main' into merge
TimDettmers May 7, 2023
4bd1151
Fixed gradient accumulation test.
TimDettmers May 7, 2023
2bce175
Fixed Makefile.
TimDettmers May 24, 2023
1b8772a
Added PagedLion and bf16 Lion.
TimDettmers May 24, 2023
0f40fa3
Bumped version.
TimDettmers May 24, 2023
46b184b
Merge remote-tracking branch 'remotes/source/main' into cmake_windows
stoperro May 25, 2023
e02f078
Disables cudaMemPrefetchAsync() in unsupported systems (e.g. Windows)…
stoperro May 29, 2023
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -228,3 +228,14 @@ Deprecated:
Features:
- Added Int8 SwitchBack layers
- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)


### 0.39.0


Features:
- 4-bit matrix multiplication for Float4 and NormalFloat4 data types.
- Added 4-bit quantization routines
- Doubled quantization routines for 4-bit quantization
- Paged optimizers for Adam and Lion.
- bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states.
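
To make the 0.39.0 feature list concrete, below is a minimal sketch combining the new 4-bit linear layers with a paged optimizer. The `bnb.nn.LinearNF4` and `bnb.optim.PagedAdamW` names are assumed from this release and may differ in this branch; the 4-bit weights stay frozen, so only the bias parameters receive gradients here.

```python
import torch
import bitsandbytes as bnb

# Sketch only: a tiny model with NF4 4-bit linear layers and a paged AdamW
# optimizer whose state tensors live in paged (unified) memory.
model = torch.nn.Sequential(
    bnb.nn.LinearNF4(1024, 4096),
    torch.nn.GELU(),
    bnb.nn.LinearNF4(4096, 1024),
).cuda()  # moving to the GPU triggers the 4-bit quantization of the weights

optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
loss = model(x).float().pow(2).mean()
loss.backward()    # 4-bit weights are frozen; only the biases get gradients
optimizer.step()
optimizer.zero_grad()
```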
26 changes: 6 additions & 20 deletions Makefile
@@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))

GPP:= /usr/bin/g++
#GPP:= /sw/gcc/11.2.0/bin/g++
ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
@@ -12,6 +13,7 @@ CUDA_VERSION:=
endif



NVCC := $(CUDA_HOME)/bin/nvcc

###########################################
@@ -23,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c

INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib

# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
@@ -38,11 +39,6 @@ CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler
CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler

# Later versions of CUDA support the new architectures
CC_CUDA10x += -gencode arch=compute_75,code=sm_75

CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110 += -gencode arch=compute_80,code=sm_80

CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86
@@ -59,21 +55,11 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90


all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
all: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda110_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
1 change: 1 addition & 0 deletions bitsandbytes/__init__.py
@@ -10,6 +10,7 @@
matmul,
matmul_cublas,
mm_cublas,
matmul_4bit
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
72 changes: 68 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -2,7 +2,7 @@
import warnings
from dataclasses import dataclass
from functools import reduce # Required in Python 3
from typing import Tuple, Optional
from typing import Tuple, Optional, List

import torch

@@ -424,10 +424,10 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA)
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None]
ctx.tensors = [None, None, A]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

@@ -440,7 +440,7 @@ def backward(ctx, grad_output):
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA = ctx.tensors
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
@@ -486,6 +486,65 @@ def backward(ctx, grad_output):

return grad_A, grad_B, None, grad_bias, None


class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
B_shape = state[1]
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)


# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)

# 3. Save state
ctx.state = state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B)
else:
ctx.tensors = (None, None)

return output

@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
A, B = ctx.tensors
state = ctx.state

grad_A, grad_B, grad_bias = None, None, None

if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())

return grad_A, grad_B, None, grad_bias, None


def matmul(
A: tensor,
B: tensor,
@@ -498,3 +557,8 @@ def matmul(
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, bias, state)


def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
assert quant_state is not None
return MatMul4Bit.apply(A, B, out, bias, quant_state)
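
A usage note for the new matmul_4bit entry point: the sketch below quantizes a linear-layer weight with bitsandbytes.functional.quantize_fp4 and multiplies against it, following the pattern from the library's tests (passing the transposed packed weight plus its quant_state). Shapes are hypothetical and behavior in this branch may differ slightly.

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Hypothetical shapes: (batch, in_features) activations against an
# (out_features, in_features) weight, as in a standard linear layer.
A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
W = torch.randn(2048, 1024, dtype=torch.float16, device="cuda")

# Blockwise FP4 quantization; quant_state carries the absmax statistics and
# the original shape needed to dequantize later.
W_fp4, quant_state = F.quantize_fp4(W)

# MatMul4Bit.forward dequantizes the packed weight on the fly and calls
# torch.nn.functional.linear, so the result approximates A @ W.t().
out = bnb.matmul_4bit(A, W_fp4.t(), quant_state)
print(out.shape)  # torch.Size([4, 2048])
```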
7 changes: 5 additions & 2 deletions bitsandbytes/cextension.py
@@ -25,14 +25,17 @@
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_g32
lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
except AttributeError as ex:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
print(str(ex))


# print the setup details after checking for errors so we do not print twice
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
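
For context on the COMPILED_WITH_CUDA flag set above, a small hedged sketch of how downstream code can guard GPU-only paths on it (the flag is re-exported from the package root via the __init__.py shown earlier in this diff):

```python
import torch
import bitsandbytes as bnb

# COMPILED_WITH_CUDA is True only if the native library loaded and exposes the
# expected symbols (e.g. cadam32bit_grad_fp32 checked in cextension.py).
if bnb.COMPILED_WITH_CUDA and torch.cuda.is_available():
    layer = bnb.nn.Linear4bit(1024, 1024).cuda()  # GPU path: 4-bit linear
else:
    layer = torch.nn.Linear(1024, 1024)           # full-precision CPU fallback
```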
4 changes: 4 additions & 0 deletions bitsandbytes/cuda_setup/main.py
@@ -51,6 +51,9 @@ def __init__(self):
raise RuntimeError("Call get_instance() instead")

def generate_instructions(self):
if getattr(self, 'error', False): return
print(self.error)
self.error = True
if self.cuda is None:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
@@ -100,6 +103,7 @@ def initialize(self):
self.has_printed = False
self.lib = None
self.initialized = False
self.error = False

def run_cuda_setup(self):
self.initialized = True