Skip to content

Commit

Permalink
Don't crash Python interpreter via assert(false) (#998)
Browse files: browse the repository at this point in the history
  • Loading branch information
akx authored Jan 30, 2024
1 parent 706ec24 commit 29a637b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
5 changes: 4 additions & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,7 +1944,10 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
)

if has_error == 1:
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")

if has_error:
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')

Expand Down
13 changes: 4 additions & 9 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <cassert>
#include <common.h>

#define ERR_NOT_IMPLEMENTED 100


using namespace BinSearch;
using std::cout;
Expand Down Expand Up @@ -421,14 +423,7 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);

return 0;
return ERR_NOT_IMPLEMENTED;
#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
Expand Down Expand Up @@ -484,7 +479,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
printf("error detected");

return has_error;
#endif
#endif // NO_CUBLASLT
}

int fill_up_to_nearest_multiple(int value, int multiple)
Expand Down

0 comments on commit 29a637b

Please sign in to comment.