CUDA CMake script fixes; some CPU-only fixes.
acpopescu committed Mar 29, 2023
1 parent f9a7774 commit 8fce35b
Showing 9 changed files with 197 additions and 125 deletions.
96 changes: 63 additions & 33 deletions CMakelists.txt
@@ -8,34 +8,61 @@ set(FILES_CUDA csrc/ops.cu csrc/kernels.cu)
set(FILES_CPP csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.c)

option(MAKE_CUDA_BUILD "Build using CUDA" ON)
option(NO_CUBLASLT "Don't use CUBLAST" OFF)
set(CUDA_VERSION "11.7" CACHE STRING "CUDA Version DLL Name: 11.0, 11.7, 11.6")
set(CUDA_VERSION_FIXED "")
set(CUDA_VERSION_MAJOR "")
string(REPLACE "." "" CUDA_VERSION_FIXED "${CUDA_VERSION}")
string(REGEX MATCH "[^\.]+" CUDA_VERSION_MAJOR "${CUDA_VERSION}")

# Later versions of CUDA support the new architectures
set(CC_CUDA10x 75)
set(CC_CUDA110 75 80)
set(CC_CUDA11x 75 80 86)
set(CC_CUDA12x 89 90)
option(USE_AVX2 "AVX2 Instruction Set" ON)
option(USE_AVX "AVX Instruction Set" ON)

if(USE_AVX2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
elseif(USE_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /fp:fast")

add_definitions(-DUSE_AVX2 -DUSE_AVX)
if( MAKE_CUDA_BUILD )
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if(${CUDA_VERSION} EQUAL "11.0")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA110})
elseif(${CUDA_VERSION_MAJOR} EQUAL "11")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA11x})
elseif(${CUDA_VERSION_MAJOR} EQUAL "12")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA12x})
else()
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA10x})
set(NO_CUBLASLT ON)
endif()
# Later versions of CUDA support the new architectures
set(CC_CUDA10x 75)
set(CC_CUDA110 75 80)
set(CC_CUDA11x 75 80 86)
set(CC_CUDA12x 89 90)

option(NO_CUBLASLT "Don't use cuBLASLt" OFF)
set(CUDA_TARGET_ARCH_FEATURE_LEVEL "11.x" CACHE STRING
"CUDA Target Architectures by Feature Level. DLL name is autodetected from installed cuda compiler.\n \
Examples : 10.0, 11.0, 11.x, 12.x\n \
\n \
Note : to change the CUDA compiler you're using and the DLL NAME\n \
- Clean the build folder\n \
- when prompted for additional parameters of the Visual C Generator add 'cuda=11.6'\n ")

set(CUDA_VERSION_DLLNAME "")
set(CUDA_VERSION_TARGET_FEATURE_MAJOR "")
string(REGEX MATCH "[0123456789]+\.[0123456789]+" CUDA_VERSION_DLLNAME "${CMAKE_CUDA_COMPILER_VERSION}")
string(REPLACE "." "" CUDA_VERSION_DLLNAME "${CUDA_VERSION_DLLNAME}")
string(REGEX MATCH "[^\.]+" CUDA_VERSION_TARGET_FEATURE_MAJOR "${CUDA_TARGET_ARCH_FEATURE_LEVEL}")

message(CONFIGURE_LOG "\nConfiguring using Cuda Compiler ${CMAKE_CUDA_COMPILER_VERSION}; Visual Studio Integration: ${CMAKE_VS_PLATFORM_TOOLSET_CUDA}\n")

if(${CUDA_TARGET_ARCH_FEATURE_LEVEL} STREQUAL "11.0")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA110})
elseif(${CUDA_VERSION_TARGET_FEATURE_MAJOR} STREQUAL "11")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA11x})
elseif(${CUDA_VERSION_TARGET_FEATURE_MAJOR} STREQUAL "12")
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA12x})
else()
set(CMAKE_CUDA_ARCHITECTURES ${CC_CUDA10x})
set(NO_CUBLASLT ON)
endif()


message(CONFIGURE_LOG " CUDA Targeting feature level ${CUDA_TARGET_ARCH_FEATURE_LEVEL}, with architectures ${CMAKE_CUDA_ARCHITECTURES}")

set (LIBBITSANDBYTESNAME "libbitsandbytes_cuda${CUDA_VERSION_DLLNAME}")
if(NO_CUBLASLT)
set (LIBBITSANDBYTESNAME "libbitsandbytes_cuda${CUDA_VERSION_DLLNAME}_nocublaslt")
endif(NO_CUBLASLT)

message(CONFIGURE_LOG " Shared library name being used: ${LIBBITSANDBYTESNAME}")

if(NOT DEFINED CMAKE_CUDA_STANDARD)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -59,11 +86,9 @@ if( MAKE_CUDA_BUILD )
POSITION_INDEPENDENT_CODE ON
CUDA_SEPARABLE_COMPILATION ON
PREFIX ""
OUTPUT_NAME "libbitsandbytes_cuda${CUDA_VERSION_FIXED}"
OUTPUT_NAME "${LIBBITSANDBYTESNAME}"
LINKER_LANGUAGE C
WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
# add_link_options(-Wl,--export-all-symbols)
# add_link_options( -Wl,--add-stdcall-alias)
target_include_directories(libbitsandbytes_cuda PRIVATE
"${PROJECT_SOURCE_DIR}/csrc/"
"${PROJECT_SOURCE_DIR}/include/"
@@ -75,14 +100,18 @@ if( MAKE_CUDA_BUILD )
curand
cusparse
)
add_custom_command(TARGET libbitsandbytes_cuda POST_BUILD # post-build step: copy the built library into the bitsandbytes Python package folder
COMMAND ${CMAKE_COMMAND} -E copy_directory # cmake -E copy_directory <target output dir> <package dir>
$<TARGET_FILE_DIR:libbitsandbytes_cuda>
"${PROJECT_SOURCE_DIR}/bitsandbytes" )


endif(MAKE_CUDA_BUILD)

add_library(libbitsandbytes_cpu SHARED
${FILES_CPP}
)
set_source_files_properties(${FILES_CPP} PROPERTIES LANGUAGE CXX)
#add_link_options(-Wl,--export-all-symbols)
#add_link_options(-Wl,--add-stdcall-alias)
set_target_properties(libbitsandbytes_cpu PROPERTIES
POSITION_INDEPENDENT_CODE ON
WINDOWS_EXPORT_ALL_SYMBOLS TRUE
@@ -92,6 +121,7 @@ target_include_directories(libbitsandbytes_cpu PRIVATE
"${PROJECT_SOURCE_DIR}/csrc/"
"${PROJECT_SOURCE_DIR}/include/"
)



add_custom_command(TARGET libbitsandbytes_cpu POST_BUILD # post-build step: copy the built library into the bitsandbytes Python package folder
COMMAND ${CMAKE_COMMAND} -E copy_directory # cmake -E copy_directory <target output dir> <package dir>
$<TARGET_FILE_DIR:libbitsandbytes_cpu>
"${PROJECT_SOURCE_DIR}/bitsandbytes" )
2 changes: 1 addition & 1 deletion bitsandbytes/cuda_setup/main.py
@@ -389,7 +389,7 @@ def evaluate_cuda_setup():
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('='*80)
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu'+SHARED_LIB_EXTENSION, None, None, None, None

cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
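
The CPU fallback above now builds the library name from `SHARED_LIB_EXTENSION` instead of hard-coding `.so`. That constant is defined elsewhere in the package; purely as an assumption, a platform-dependent definition could look like this:

```python
import platform

# Assumed sketch only: the real SHARED_LIB_EXTENSION lives elsewhere in the
# codebase. The idea is a platform-dependent shared-library suffix.
SHARED_LIB_EXTENSION = {
    "Windows": ".dll",
    "Darwin": ".dylib",
}.get(platform.system(), ".so")

print("libbitsandbytes_cpu" + SHARED_LIB_EXTENSION)  # e.g. libbitsandbytes_cpu.dll
```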
24 changes: 15 additions & 9 deletions bitsandbytes/functional.py
@@ -109,13 +109,15 @@ def get_instance(cls):

def get_context(self, device):
if device.index not in self.context:
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
self.context[device.index] = ct.c_void_p(lib.get_context())
torch.cuda.set_device(prev_device)
if torch.cuda.is_available():
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
self.context[device.index] = ct.c_void_p(lib.get_context())
torch.cuda.set_device(prev_device)
else:
self.context[device.index] = ct.c_void_p(lib.get_context())
return self.context[device.index]


class Cusparse_Context:
_instance = None

@@ -302,13 +304,17 @@ def get_ptr(A: Tensor) -> ct.c_void_p:


def pre_call(device):
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
return prev_device
if torch.cuda.is_available():
prev_device = torch.cuda.current_device()
torch.cuda.set_device(device)
return prev_device
else:
return device


def post_call(prev_device):
torch.cuda.set_device(prev_device)
if torch.cuda.is_available():
torch.cuda.set_device(prev_device)


def get_transform_func(dtype, orderA, orderOut, transpose=False):
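
With these guards, `pre_call`/`post_call` simply pass the device through on CPU-only machines instead of raising inside `torch.cuda`. A minimal sketch of the usual call pattern around a native library call, as it appears inside `bitsandbytes/functional.py` where `lib`, `get_ptr`, `pre_call`, and `post_call` are already in scope (the native function name here is hypothetical):

```python
import ctypes as ct

def run_native_op(A):
    # A is assumed to be a torch.Tensor; on CUDA machines we switch to its
    # device, on CPU-only machines the helpers are now no-ops.
    prev_device = pre_call(A.device)
    lib.csome_native_op(get_ptr(A), ct.c_int32(A.numel()))  # hypothetical call
    post_call(prev_device)
```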
33 changes: 23 additions & 10 deletions compile_from_source.md
@@ -2,7 +2,7 @@

## Windows

CPU NOT TESTED
CPU-only: most tests fail, since this library requires CUDA for its core features.

Ensure you have the conda environment into which you want to bring bitsandbytes (a BLOOM setup, textgen-ui, etc.).
I'd suggest installing Mamba and using it, as it's much faster.
@@ -30,6 +30,8 @@ dependencies:
- jupyter
- notebook
- pytest
- einops
- scipy
```

2. Then open POWERSHELL
@@ -41,16 +43,27 @@ mamba env activate mycompileenv
```
At this point, select your Visual Studio installation (i.e., hit 1).

3. Go into your bitsandbytes folder and run `cmake-gui .`
4. Make sure you put the build folder correctly, append "build" to Where to build the binaries
3. Go into your bitsandbytes folder and run `cmake-gui -S . -B ./build`
5. Hit Configure
6. Set your CUDA_VERSION to whatever you have. If you deselect MAKE_CUDA_BUILD, leave as is
7. Hit Configure again, then Generate
8. Open Visual Studio and select Release as configuration. Build Solution
9. copy everything from `build\Release\*.*` over in the `bitsandbytes` folder (the one with the python modules)
10. run tests `python -m pytest`. You may need to use `mamba` to install other modules
11. build wheel `mamba install build` and then `python -m build --wheel`
12. install wheel `pip install .\dist\*.whl`
6. Set `cuda=11.7` (or another version, or nothing) in the `Optional toolset to use` field when selecting the generator. You can leave it blank. If you aren't prompted for the generator, delete the `build` folder and run cmake-gui again with the command line above.

The `Optional toolset to use` determines the DLL name. You need to have that CUDA toolkit and its Visual Studio integration installed via the NVIDIA installer.
7. You should see some info like this in the log:
```
CONFIGURE_LOG
Configuring using Cuda Compiler 11.7.64; Visual Studio Integration: 11.7
CONFIGURE_LOG CUDA Targeting feature level 11.x, with architectures 52
CONFIGURE_LOG Shared library name being used: libbitsandbytes_cuda117
```
7. Set `CUDA_TARGET_ARCH_FEATURE_LEVEL` to your desired feature level. Supported: 10.x, 11.0, 11.x, 12.x
8. `CMAKE_CUDA_ARCHITECTURES` is overridden by `CUDA_TARGET_ARCH_FEATURE_LEVEL` for now
9. Select any other options, then hit Generate.
10. Open the solution in Visual Studio, select Release as the configuration, and build the solution.
11. Everything should be copied into the right place automatically (see the sketch after this list for a quick check).
12. Run the tests with `python -m pytest`. You may need to use `mamba` to install additional modules.
13. Build the wheel: `mamba install build`, then `python -m build --wheel`.
14. Install the wheel: `pip install .\dist\*.whl`
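
As a quick sanity check for step 11, you can confirm that the post-build copy placed the freshly built shared library next to the Python modules; this snippet only assumes the naming scheme from the CMake script and is run from the repository root:

```python
import glob
import os

# The post-build step should have copied the shared library into the
# bitsandbytes package folder inside the repository.
pkg_dir = os.path.join(os.getcwd(), "bitsandbytes")
libs = glob.glob(os.path.join(pkg_dir, "libbitsandbytes_*"))
print("Found shared libraries:", libs if libs else "none - check the build output")
```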


## Linux
10 changes: 9 additions & 1 deletion csrc/common.h
@@ -7,8 +7,16 @@ using namespace BinSearch;

#define BLOCK_SIZE 16384

#if defined(USE_AVX) || defined(USE_AVX2)
#define INSTR_SET AVX
#elif defined(USE_SSE41) || defined(USE_SSE42)
#define INSTR_SET SSE
#else
#define INSTR_SET Scalar
#endif

struct quantize_block_args {
BinAlgo<AVX, float, Direct2> *bin_searcher;
BinAlgo<INSTR_SET, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
2 changes: 1 addition & 1 deletion csrc/cpu_ops.cpp
@@ -25,7 +25,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
num_blocks += n % blocksize == 0 ? 0 : 1;

const uint32 elements_code = 256;
BinAlgo<AVX, float, Direct2> bin_searcher(code, elements_code);
BinAlgo<INSTR_SET, float, Direct2> bin_searcher(code, elements_code);

int thread_wave_size = 256;
std::vector<std::future<void>> wave_storage;
9 changes: 9 additions & 0 deletions include/SIMD.h
@@ -64,6 +64,15 @@ template <> struct InstrFloatTraits<SSE, double>
typedef __m128d vec_t;
};

template<> struct InstrFloatTraits<Scalar, float>
{
typedef float vec_t;
};
template<> struct InstrFloatTraits<Scalar, double>
{
typedef double vec_t;
};

template <InstrSet I, typename T>
struct FTOITraits
{
7 changes: 6 additions & 1 deletion tests/test_cuda_setup_evaluator.py
@@ -1,6 +1,6 @@
import os
from typing import List, NamedTuple

import platform
import pytest

import bitsandbytes as bnb
@@ -91,6 +91,11 @@ def test_full_system():

# if CONDA_PREFIX exists, it has priority before all other env variables
# but it does not contain the library directly, so we need to look at the a sub-folder

# not testing windows platform
if(platform.system()=="Windows"):
return

version = ""
if "CONDA_PREFIX" in os.environ:
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
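
As a side note (a sketch, not part of this commit), an equivalent way to express the Windows bail-out is an explicit skip, so the test is reported as skipped rather than passed:

```python
import platform

import pytest


def test_full_system():
    # Alternative to the early return: report the test as skipped on Windows.
    if platform.system() == "Windows":
        pytest.skip("full-system CUDA setup test is not supported on Windows")
    # ... rest of the test unchanged ...
```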