int8 refactor: initial sparse decomp, cleanup
matthewdouglas committed Oct 14, 2024
1 parent 57e6427 commit ca372f2
Showing 9 changed files with 390 additions and 434 deletions.
142 changes: 54 additions & 88 deletions bitsandbytes/autograd/_functions.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass
from functools import reduce # Required in Python 3
import operator
from math import prod
from typing import Callable, Optional, Tuple
import warnings
from warnings import warn
@@ -9,12 +8,6 @@

import bitsandbytes.functional as F


# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)


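Not part of the diff, just an illustration: the deleted helper reimplemented the product of a shape tuple because math.prod only arrived in Python 3.8; with older interpreters no longer supported, the stdlib function replaces it directly.

```python
import math
import operator
from functools import reduce

shape = (4, 128, 768)
# The old reduce-based helper and math.prod give the same result.
assert reduce(operator.mul, shape, 1) == math.prod(shape) == 393216
```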
# The inverse transformations for the colTuring and colAmpere formats were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

@@ -284,10 +277,16 @@ def tile_indices(self):

class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
state = state or MatmulLtState()
def forward(
ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
B: torch.Tensor,
out=None,
bias: Optional[torch.Tensor] = None,
state=MatmulLtState,
):
# state = state or MatmulLtState()

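For context (not part of the diff): this autograd function is normally reached through bitsandbytes.matmul, which forwards a MatmulLtState carrying the quantized weight, its scales, and the outlier threshold. A hedged usage sketch; the shapes and the exact fields set on the state here are illustrative.

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.autograd._functions import MatmulLtState

A = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")

state = MatmulLtState()
state.threshold = 6.0          # |activation| >= 6.0 marks an outlier column
state.has_fp16_weights = True  # quantize B on the first call instead of using a pre-quantized CB

out = bnb.matmul(A, B, state=state)  # dispatches to MatMul8bitLt.apply under the hood
```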
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
@@ -300,14 +299,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
else:
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()

# Cast A to fp16
if A.dtype != torch.float16:
@@ -318,92 +310,57 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

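F.double_quant produces two int8 copies of A, one with row-wise statistics (CA, SCA) and one with column-wise statistics (CAt, SCAt), plus, when threshold > 0, a sparse tensor marking the entries whose magnitude exceeds the threshold. A simplified emulation of the row-wise half, illustrative only; the real kernel is fused and also builds the transposed copy:

```python
import torch

def rowwise_int8_quant(A: torch.Tensor, threshold: float = 6.0):
    """Row-wise absmax int8 quantization with outliers excluded from the scale."""
    outlier_mask = A.abs() >= threshold                    # these entries are handled separately in fp16
    A_regular = A.masked_fill(outlier_mask, 0.0)
    row_absmax = A_regular.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    CA = torch.round(A_regular / row_absmax * 127.0).to(torch.int8)
    return CA, row_absmax.squeeze(1), outlier_mask.nonzero()
```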
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long()
CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
subA = None
has_grad = False

# 2. Quantize B
if state.has_fp16_weights:
has_grad = True if (getattr(B, "grad", None) is not None) else False
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()

if (state.is_training and not has_grad) or state.CB is None:
state.reset_grads()

# quantize...
# 2. Quantize B
(
state.CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
_,
) = F.double_quant(B.to(torch.float16))
else:
has_grad = False

if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers

outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx

# if state.CxB is not None:
# outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
# else:
outliers = state.CB[:, state.idx.long()].clone()

state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]

shapeB = state.CB.shape

if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
else:
output_shape = (input_shape[0], shapeB[0])

# 3. Matmul
if using_igemmlt:
out32, Sout32 = F.igemmlt(CA, state.CB)
if state.threshold > 0.0 and coo_tensorA is not None:
state.idx = torch.unique(coo_tensorA._indices()[1]).long()

# Zero out the outliers in the int8 inputs
CA[:, state.idx] = 0
CAt[:, state.idx] = 0

if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
# Extract the input outliers in original precision
subA = A[:, state.idx]

# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t().contiguous()
else:
outliers = state.CB[:, state.idx].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
else:
A_wo_outliers = A.clone()
if state.idx is not None:
A_wo_outliers[:, state.idx.long()] = 0
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
if bias is not None:
output = output.add_(bias)
subA = state.subB = None

# 3. Int8 Matmul
out32, Sout32 = F.igemmlt(CA, state.CB)
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)

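igemmlt returns an int32 accumulator; mm_dequant rescales it using the per-row absmax of A (SCA) and the per-row absmax of B (SCB, which act as per-column scales of the output) and can fuse the fp16 bias add. Roughly equivalent plain-PyTorch math, illustrative and ignoring memory-layout details:

```python
import torch

def dequant_mm_sketch(out32: torch.Tensor, SCA: torch.Tensor, SCB: torch.Tensor, bias=None):
    """out32: (m, n) int32 result of CA @ CB.T; SCA: (m,) absmax of A rows; SCB: (n,) absmax of B rows."""
    out = out32.float() * (SCA.view(-1, 1) * SCB.view(1, -1)) / (127.0 * 127.0)
    return out if bias is None else out + bias
```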
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
if subA is not None and state.subB is not None:
output += torch.matmul(subA, state.subB.to(subA.dtype))

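Steps 3 and 4 together implement the LLM.int8() decomposition: the int8 matmul covers the columns whose activations stay below the threshold (their quantized values were zeroed in CA), and a small fp16 matmul over subA and subB adds back the contribution of the outlier columns. A compact dense sketch of the whole decomposition, illustrative rather than the library's code path:

```python
import torch

def int8_matmul_with_decomposition(A: torch.Tensor, B: torch.Tensor, threshold: float = 6.0):
    """A: (m, k) activations, B: (n, k) weights; returns approximately A @ B.t()."""
    idx = (A.abs() >= threshold).any(dim=0).nonzero().squeeze(1)  # outlier feature dimensions

    A_reg = A.float().clone()
    A_reg[:, idx] = 0                                             # outliers leave the int8 path
    sca = A_reg.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    scb = B.float().abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
    CA = torch.round(A_reg / sca * 127.0).to(torch.int8)
    CB = torch.round(B.float() / scb * 127.0).to(torch.int8)

    # Int8 matmul (emulated in int32) followed by dequantization.
    out = (CA.to(torch.int32) @ CB.to(torch.int32).t()).float() * (sca @ scb.t()) / (127.0 * 127.0)

    # Mixed-precision part: the outlier columns stay in higher precision.
    out += A[:, idx].float() @ B[:, idx].float().t()
    return out.to(A.dtype)
```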
# 5. Save state
ctx.state = state
@@ -419,7 +376,8 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

return output.reshape(output_shape)
output_shape = (*input_shape[:-1], state.CB.shape[0])
return output.reshape(output_shape).clone()

@staticmethod
def backward(ctx, grad_output):
@@ -441,16 +399,24 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
# grad_output.T @ A
# grad_weight = grad_output.t().mm(A)
grad_B = torch.matmul(grad_output.t(), A)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
# if req_gradB:
#
# gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
# if state.threshold > 0.0 and subA is not None:
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
# grad_output @ B.T
if state.CBt is not None:
gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
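Summarizing the refactored backward (an illustration, not the diff): the weight gradient is now a plain matmul, grad_output.T @ A, with the outlier columns added back from the saved subA slice, while the input gradient still goes through the quantized weights, either an int8 matmul against CBt or a dense matmul against the dequantized CB. The sketch below assumes the saved activation already has its outlier columns zeroed, so the subA term restores their contribution:

```python
import torch

def matmul8bitlt_backward_sketch(grad_output, A_reg, subA, idx, CB, SCB, dtype_A=torch.float16):
    """grad_output: (m, n); A_reg: (m, k) input with outlier columns zeroed;
    subA: (m, len(idx)) saved outlier columns; CB: (n, k) int8 weights; SCB: (n,) absmax of B rows."""
    # dL/dB = grad_output.T @ A, split into the regular part plus the outlier correction.
    grad_B = grad_output.t() @ A_reg
    if subA is not None:
        grad_B[:, idx] += grad_output.t() @ subA

    # dL/dA = grad_output @ B, with B reconstructed as CB * SCB / 127.
    B_dequant = CB.float() * SCB.view(-1, 1) / 127.0
    grad_A = (grad_output.float() @ B_dequant).to(dtype_A)
    return grad_A, grad_B
```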
3 changes: 3 additions & 0 deletions bitsandbytes/cextension.py
@@ -67,6 +67,9 @@ def __init__(self, lib: ct.CDLL):
def __getattr__(self, item):
return getattr(self._lib, item)

def __getitem__(self, item):
return getattr(self._lib, item)


class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True
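The new __getitem__ mirrors __getattr__, so a native symbol can be fetched with a name assembled at runtime as well as via attribute access. A minimal illustration; the library filename and symbol name are hypothetical:

```python
import ctypes as ct
from bitsandbytes.cextension import BNBNativeLibrary

lib = BNBNativeLibrary(ct.CDLL("libbitsandbytes_cpu.so"))  # hypothetical library path

name = "cigemmlt_32"          # hypothetical symbol name built at runtime
fn_by_attr = lib.cigemmlt_32  # existing __getattr__ path
fn_by_item = lib[name]        # new __getitem__ path; resolves the same native function
```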