diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 19f224391..9917e326e 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -2,7 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn
 
 import torch
@@ -14,9 +14,6 @@ def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
 
-tensor = torch.Tensor
-
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
 
@@ -56,7 +53,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
 
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -549,10 +549,10 @@ def backward(ctx, grad_output):
 
 
 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index af32819df..a5931ef5e 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -34,9 +34,9 @@
 # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
 system = platform.system()
 if system == 'Windows':
-    CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
+    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
 else: # Linux or other
-    CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
+    CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
 
 # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
 backup_paths = []
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 739d922a4..25aa4e531 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -12,7 +12,7 @@
 import numpy as np
 
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Any, Dict
+from typing import Tuple, Any, Dict, Optional
 from torch import Tensor
 
 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
@@ -27,71 +27,83 @@ def prod(iterable):
 
 if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
-    str2optimizer32bit = {}
-    str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
-    str2optimizer32bit["momentum"] = (
-        lib.cmomentum32bit_grad_32,
-        lib.cmomentum32bit_grad_16,
-    )
-    str2optimizer32bit["rmsprop"] = (
-        lib.crmsprop32bit_grad_32,
-        lib.crmsprop32bit_grad_16,
-    )
-    str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
-    str2optimizer32bit["adagrad"] = (
-        lib.cadagrad32bit_grad_32,
-        lib.cadagrad32bit_grad_16,
-    )
+    str2optimizer32bit = {
+        "adam": (
+            lib.cadam32bit_grad_fp32,
+            lib.cadam32bit_grad_fp16,
+            lib.cadam32bit_grad_bf16,
+        ),
+        "momentum": (
+            lib.cmomentum32bit_grad_32,
+            lib.cmomentum32bit_grad_16,
+        ),
+        "rmsprop": (
+            lib.crmsprop32bit_grad_32,
+            lib.crmsprop32bit_grad_16,
+        ),
+        "lion": (
+            lib.clion32bit_grad_fp32,
+            lib.clion32bit_grad_fp16,
+            lib.clion32bit_grad_bf16,
+        ),
+        "adagrad": (
+            lib.cadagrad32bit_grad_32,
+            lib.cadagrad32bit_grad_16,
+        ),
+    }
+
+    str2optimizer8bit = {
+        "adam": (
+            lib.cadam_static_8bit_grad_32,
+            lib.cadam_static_8bit_grad_16,
+        ),
+        "momentum": (
+            lib.cmomentum_static_8bit_grad_32,
+            lib.cmomentum_static_8bit_grad_16,
+        ),
+        "rmsprop": (
+            lib.crmsprop_static_8bit_grad_32,
+            lib.crmsprop_static_8bit_grad_16,
+        ),
+        "lion": (
+            lib.clion_static_8bit_grad_32,
+            lib.clion_static_8bit_grad_16,
+        ),
+        "lamb": (
+            lib.cadam_static_8bit_grad_32,
+            lib.cadam_static_8bit_grad_16,
+        ),
+        "lars": (
+            lib.cmomentum_static_8bit_grad_32,
+            lib.cmomentum_static_8bit_grad_16,
+        ),
+    }
+
+    str2optimizer8bit_blockwise = {
+        "adam": (
+            lib.cadam_8bit_blockwise_grad_fp32,
+            lib.cadam_8bit_blockwise_grad_fp16,
+            lib.cadam_8bit_blockwise_grad_bf16,
+        ),
+        "momentum": (
+            lib.cmomentum_8bit_blockwise_grad_fp32,
+            lib.cmomentum_8bit_blockwise_grad_fp16,
+        ),
+        "rmsprop": (
+            lib.crmsprop_8bit_blockwise_grad_fp32,
+            lib.crmsprop_8bit_blockwise_grad_fp16,
+        ),
+        "lion": (
+            lib.clion_8bit_blockwise_grad_fp32,
+            lib.clion_8bit_blockwise_grad_fp16,
+            lib.clion_8bit_blockwise_grad_bf16,
+        ),
+        "adagrad": (
+            lib.cadagrad_8bit_blockwise_grad_fp32,
+            lib.cadagrad_8bit_blockwise_grad_fp16,
+        ),
+    }
 
-    str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (
-        lib.cadam_static_8bit_grad_32,
-        lib.cadam_static_8bit_grad_16,
-    )
-    str2optimizer8bit["momentum"] = (
-        lib.cmomentum_static_8bit_grad_32,
-        lib.cmomentum_static_8bit_grad_16,
-    )
-    str2optimizer8bit["rmsprop"] = (
-        lib.crmsprop_static_8bit_grad_32,
-        lib.crmsprop_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lion"] = (
-        lib.clion_static_8bit_grad_32,
-        lib.clion_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lamb"] = (
-        lib.cadam_static_8bit_grad_32,
-        lib.cadam_static_8bit_grad_16,
-    )
-    str2optimizer8bit["lars"] = (
-        lib.cmomentum_static_8bit_grad_32,
-        lib.cmomentum_static_8bit_grad_16,
-    )
-
-    str2optimizer8bit_blockwise = {}
-    str2optimizer8bit_blockwise["adam"] = (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    )
-    str2optimizer8bit_blockwise["momentum"] = (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-    )
-    str2optimizer8bit_blockwise["rmsprop"] = (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-    )
-    str2optimizer8bit_blockwise["lion"] = (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    )
-    str2optimizer8bit_blockwise["adagrad"] = (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-    )
 
 
 class GlobalPageManager:
     _instance = None
@@ -400,7 +412,8 @@ def is_on_gpu(tensors):
         raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
     return on_gpu
 
-def get_ptr(A: Tensor) -> ct.c_void_p:
+
+def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
     """
     Get the ctypes pointer from a PyTorch Tensor.
 
@@ -521,7 +534,7 @@ def nvidia_transform(
     return out, new_state
 
 
-def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
 
@@ -626,8 +639,8 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
 
         # unpacking minor and non-tensor quant state items if necessary
        if len(qs_key) == 1:
-            qs_key = qs_key[0]
-            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+            first_qs_key = qs_key[0]
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))
 
         qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
         assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
@@ -694,7 +707,14 @@ def to(self, device):
             self.state2.code = self.state2.code.to(device)
 
 
-def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
+def quantize_blockwise(
+    A: Tensor,
+    code: Optional[torch.Tensor] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=4096,
+    nested=False,
+) -> Tuple[Tensor, QuantState]:
     """
     Quantize tensor A in blocks of size 4096 values.
 
@@ -769,10 +789,10 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
 
 def dequantize_blockwise(
     A: Tensor,
-    quant_state: QuantState = None,
-    absmax: Tensor = None,
-    code: Tensor = None,
-    out: Tensor = None,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     blocksize: int = 4096,
     nested=False
 ) -> Tensor:
@@ -891,17 +911,17 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
 
-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
 
-def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
 
 
 def quantize_4bit(
     A: Tensor,
-    absmax: Tensor = None,
-    out: Tensor = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     blocksize=64,
     compress_statistics=False,
     quant_type='fp4',
@@ -987,13 +1007,13 @@ def quantize_4bit(
     return out, state
 
 
-def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
 
-def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
 
-def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
+def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
 
@@ -1070,7 +1090,11 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
     else:
         return out
 
-def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
+def quantize(
+    A: Tensor,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -1086,10 +1110,10 @@
 
 def dequantize(
     A: Tensor,
-    state: Tuple[Tensor, Tensor] = None,
-    absmax: Tensor = None,
-    code: Tensor = None,
-    out: Tensor = None,
+    state: Optional[Tuple[Tensor, Tensor]] = None,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
 ) -> Tensor:
     assert state is not None or absmax is not None
     if code is None and state is None:
@@ -1104,7 +1128,7 @@ def dequantize(
     return out * state[0]
 
 
-def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     '''
     Quantizes input tensor to 8-bit.
 
@@ -1133,7 +1157,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     return out
 
 
-def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
+def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     '''
     Dequantizes the 8-bit tensor to 32-bit.
 
@@ -1171,11 +1195,11 @@ def optimizer_update_32bit(
     eps: float,
     step: int,
     lr: float,
-    state2: Tensor = None,
+    state2: Optional[torch.Tensor] = None,
     beta2: float = 0.0,
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
-    unorm_vec: Tensor = None,
+    unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
     skip_zeros=False,
 ) -> None:
@@ -1274,7 +1298,7 @@ def optimizer_update_8bit(
     new_max2: Tensor,
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
-    unorm_vec: Tensor = None,
+    unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
 ) -> None:
     """
@@ -1603,7 +1627,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
 def gemv_4bit(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
     state=None
@@ -1663,7 +1687,7 @@ def gemv_4bit(
 def igemm(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
 ):
@@ -1752,7 +1776,7 @@ def igemm(
 def batched_igemm(
     A: Tensor,
     B: Tensor,
-    out: Tensor = None,
+    out: Optional[torch.Tensor] = None,
     transposed_A=False,
     transposed_B=False,
 ):
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 0b1dc5c6f..b1f6deb21 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -145,7 +145,7 @@ def __new__(
         cls,
         data: Optional[torch.Tensor] = None,
         requires_grad=True,
-        quant_state: QuantState = None,
+        quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
         quant_type: str = 'fp4',
diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py
index de07ac647..67b45f4a5 100644
--- a/bitsandbytes/nn/triton_based_modules.py
+++ b/bitsandbytes/nn/triton_based_modules.py
@@ -162,7 +162,7 @@ def __init__(
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
 
-        if not is_triton_available:
+        if not is_triton_available():
             raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                                Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
 
diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py
index 0dff351e0..06b0748ff 100644
--- a/bitsandbytes/research/autograd/_functions.py
+++ b/bitsandbytes/research/autograd/_functions.py
@@ -2,6 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
+from typing import Optional
 
 import torch
 
@@ -14,7 +15,6 @@ def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
 
-tensor = torch.Tensor
 
 class MatMulFP8Mixed(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
@@ -389,19 +389,38 @@ def get_block_sizes(input_matrix, weight_matrix):
 
     return bsz, bsz2
 
-def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
+
+def matmul_fp8_global(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    fw_code: torch.Tensor,
+    bw_code: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bsz: int = -1,
+    bsz2: int = -1,
+):
     if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
-def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
+
+def matmul_fp8_mixed(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    fw_code: torch.Tensor,
+    bw_code: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    bsz: int = -1,
+    bsz2: int = -1,
+):
     if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
     return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
 
+
 def switchback_bnb(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
diff --git a/pyproject.toml b/pyproject.toml
index 74d17dd90..c73f579e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,4 +34,12 @@ ignore-init-module-imports = true # allow to expose in __init__.py via imports
 combine-as-imports = true
 detect-same-package = true
 force-sort-within-sections = true
-known-first-party = ["bitsandbytes"]
\ No newline at end of file
+known-first-party = ["bitsandbytes"]
+
+[[tool.mypy.overrides]]
+module = "triton.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "scipy.stats"
+ignore_missing_imports = true
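Note on the one behavioral change in this patch, the added parentheses in SwitchBackLinear: in Python a bare function object is always truthy, so the old guard "if not is_triton_available:" could never fire. The snippet below is a minimal, self-contained sketch of that pitfall; the is_triton_available defined here is a stand-in written for illustration, not the actual bitsandbytes implementation.

# Illustration only -- a stand-in availability probe, not bitsandbytes code.
def is_triton_available() -> bool:
    try:
        import triton  # noqa: F401
        return True
    except ImportError:
        return False

# Old check: "not <function object>" is always False, so this branch never runs.
if not is_triton_available:
    print("unreachable: a function object is always truthy")

# Fixed check: call the function and test its return value.
if not is_triton_available():
    print("triton is missing; this is where SwitchBackLinear raises ImportError")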