Fix some issues found by Mypy (#995)
* Fix erroneous type aliasing

* Fix `Optional` typings (see PEP 484 and the sketch below)

* Add Mypy ignores

* Fix Mypy complaints for method tables

* Fix type for get_ptr

* Fix various Mypy errors

* Fix missed call to is_triton_available
akx authored Jan 29, 2024
1 parent 32be289 commit a8c9dfa
Showing 7 changed files with 168 additions and 117 deletions.
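
For context on the PEP 484 bullet above: the spec deprecated "implicit Optional", so a parameter whose default is `None` must be annotated `Optional[T]` rather than plain `T`. A minimal sketch of the pattern this commit applies throughout (the function name here is hypothetical, not from the diff):

import torch
from typing import Optional

# Mypy (under no_implicit_optional, the PEP 484 reading) rejects
#     def scale(x: torch.Tensor, out: torch.Tensor = None): ...
# because the default `None` is not a valid `torch.Tensor`.

def scale(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Spelling out the optionality keeps the convenient `None` default.
    if out is None:
        out = torch.empty_like(x)
    torch.mul(x, 2, out=out)
    return out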
22 changes: 11 additions & 11 deletions bitsandbytes/autograd/_functions.py
@@ -2,7 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn
 
 import torch
@@ -14,9 +14,6 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
-tensor = torch.Tensor
-
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
 
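The three removed lines above are the "erroneous type aliasing" from the commit message: `tensor = torch.Tensor` is a plain assignment, so Mypy sees `tensor` as a variable rather than a type, and annotations such as `A: tensor` are not reliably understood. The commit simply writes `torch.Tensor` directly; if an alias were genuinely wanted, an explicit declaration is the Mypy-friendly form (a sketch, not part of this diff):

import torch
from typing_extensions import TypeAlias  # typing.TypeAlias on Python 3.10+

# Declared alias: Mypy now treats `Tensor` as a name for the type itself.
Tensor: TypeAlias = torch.Tensor

def double(x: Tensor) -> Tensor:
    return x * 2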
@@ -56,7 +53,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -549,10 +549,10 @@ def backward(ctx, grad_output):
 
 
 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
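One note on the `get_inverse_transform_indices` hunk above: the builtin `callable` is a predicate function, not a type, so Mypy rejects it as an annotation; `Callable[[torch.Tensor], torch.Tensor]` fixes the error and also documents the expected signature. A short sketch exercising the new annotation (the transform and driver here are made-up examples, not library code):

from typing import Callable, Tuple

import torch

def transpose_tile(tile: torch.Tensor) -> torch.Tensor:
    # A toy tile transform that satisfies the annotated signature.
    return tile.t().contiguous()

def apply_transform(
    transform_tile: Callable[[torch.Tensor], torch.Tensor],
    tile_size: Tuple[int, int],
) -> torch.Tensor:
    tile = torch.arange(tile_size[0] * tile_size[1]).reshape(tile_size)
    return transform_tile(tile)

print(apply_transform(transpose_tile, (2, 4)).shape)  # torch.Size([4, 2])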
4 changes: 2 additions & 2 deletions bitsandbytes/cuda_setup/main.py
@@ -34,9 +34,9 @@
 # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
 system = platform.system()
 if system == 'Windows':
-    CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
+    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
 else: # Linux or other
-    CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
+    CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
 
 # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
 backup_paths = []
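On the `CUDA_RUNTIME_LIBS` hunk above: annotating the same module-level name in both branches of the `if`/`else` makes Mypy report the second annotation as a re-declaration (`[no-redef]`), and the bare `list` annotation also throws away the element type. With the annotations dropped, Mypy infers a single `list[str]` from the assignments themselves. The same pattern in miniature (variable name hypothetical):

import platform

# No annotations needed: Mypy infers list[str] for both branches.
if platform.system() == "Windows":
    runtime_libs = ["nvcuda.dll"]
else:
    runtime_libs = ["libcudart.so", "libcudart.so.12.2"]

print(runtime_libs)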
[Diffs for the remaining five changed files are not shown.]
