Skip to content

Commit

Permalink
New naive mm_dequant kernel for row-major; cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Oct 9, 2024
1 parent 0f2dc34 commit 50fe50e
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 276 deletions.
28 changes: 5 additions & 23 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def tile_indices(self):

class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
state = state or MatmulLtState()

using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
Expand Down Expand Up @@ -417,8 +419,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
return clone_func(output.view(output_shape))
return output.reshape(output_shape)

@staticmethod
def backward(ctx, grad_output):
Expand All @@ -442,37 +443,18 @@ def backward(ctx, grad_output):

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
# CxAt, SAt = F.transform(CAt, formatB, transpose=True)
# C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
# gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
gradB32, SgradB32 = F.igemmlt(
Cgradt.t(), CAt.t()
) # issue here in test_linear_serialization w/ has fp16 weights
gradB32, SgradB32 = F.igemmlt(Cgradt.t(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
if state.CBt is not None:
# C32grad, Sgrad = F.transform(Cgrad, "col32")
# if state.CxBt is None:
# state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
# grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
gradA32, SgradA32 = F.igemmlt(Cgradt, state.CBt.t())
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
# elif state.CxB is not None:
# CB = (
# undo_layout(state.CxB, state.tile_indices)
# .to(ctx.dtype_A)
# .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
# )
# grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception("State must contain either CBt or CB matrix for backward")

Expand Down
111 changes: 75 additions & 36 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,7 +2330,7 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)

assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ B={shapeA}"
assert lda == ldb, f"igemmlt only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"

prev_device = A.device
torch.cuda.set_device(A.device)
Expand Down Expand Up @@ -2361,18 +2361,25 @@ def igemmlt(A, B, out=None, Sout=None, dtype=torch.int32):
return out, Sout


def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
def mm_dequant_torch(
A: torch.Tensor,
quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format)
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
new_row_stats=None, # TODO: unused
new_col_stats=None, # TODO: unused
bias: Optional[torch.Tensor] = None,
):
assert A.dtype == torch.int32

compute_dtype = torch.float32

A_calc = A.view(-1, A.shape[-1]).to(compute_dtype)
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
A_calc = A.view(-1, A.shape[-1])
row_stats = row_stats.reshape(-1).unsqueeze(-1)
col_stats = col_stats.reshape(-1).unsqueeze(0)

# TODO support out != None

out = A_calc * (row_stats * col_stats) * 6.200124e-5 # .to(torch.float16)
out = A_calc * (row_stats * col_stats) * 6.200124e-5

if bias is not None:
# assert bias.dtype == torch.float16
Expand All @@ -2381,42 +2388,40 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non
return out.to(torch.float16)


def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
def mm_dequant(
A: torch.Tensor,
quant_state: Optional[Tuple[torch.Size, str]], # TODO: deprecate. (shape, format)
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
new_row_stats=None, # TODO: unused
new_col_stats=None, # TODO: unused
bias: Optional[torch.Tensor] = None,
):
assert A.dtype == torch.int32

if bias is not None:
assert bias.dtype == torch.float16
out_shape = quant_state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])

if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None:
new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
out = torch.empty_like(A, dtype=torch.float16)

prev_device = pre_call(A.device)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrNewRowStats = get_ptr(new_row_stats)
ptrNewColStats = get_ptr(new_col_stats)
ptrBias = get_ptr(bias)
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(A.shape[-1])

is_on_gpu([A, row_stats, col_stats, out, bias])

is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
prev_device = pre_call(A.device)
lib.cdequant_mm_int32_fp16(
ptrA,
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
ptrBias,
numRows,
numCols,
Expand All @@ -2426,7 +2431,33 @@ def mm_dequant_old(A, quant_state, row_stats, col_stats, out=None, new_row_stats
return out


def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
def get_colrow_absmax(
    A: torch.Tensor,
    row_stats: Optional[torch.Tensor] = None,
    col_stats: Optional[torch.Tensor] = None,
    nnz_block_ptr: Optional[torch.Tensor] = None,
    threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Compute the per-row and per-column absolute maxima of ``A``.

    ``A`` is treated as a 2D matrix of shape ``[-1, A.shape[-1]]``; any
    leading dimensions are flattened into the row dimension.

    Args:
        A: Floating-point input tensor.
        row_stats: Precomputed row statistics. Returned unchanged when given.
        col_stats: Precomputed column statistics. Returned unchanged when given.
        nnz_block_ptr: Returned unchanged when given.
        threshold: When greater than 0 and ``nnz_block_ptr`` is not given, a
            zero-filled int32 tensor shaped like ``A`` is returned as a
            placeholder (the threshold itself is not applied yet).

    Returns:
        A ``(row_stats, col_stats, nnz_block_ptr)`` tuple. Statistics computed
        here are 1D float32 tensors of length ``rows`` and ``cols``
        respectively; ``nnz_block_ptr`` is ``None`` unless it was passed in or
        ``threshold > 0``.

    Raises:
        AssertionError: If ``A`` is not a floating-point tensor.
    """
    # Note: prior impl only works with fp16
    assert A.is_floating_point()

    if row_stats is None or col_stats is None:
        # reshape (not view): the elementwise abs() result keeps the strides of
        # a non-contiguous input, and view() would raise when the leading
        # dimensions cannot be merged without a copy.
        absA = A.abs().reshape(-1, A.shape[-1])  # 2D: [rows, cols]
        if row_stats is None:
            # Reduce over columns -> shape [rows]
            row_stats = absA.amax(dim=1, keepdim=False).float()
        if col_stats is None:
            # Reduce over rows -> shape [cols]
            col_stats = absA.amax(dim=0, keepdim=False).float()

    # TODO: threshold support
    if nnz_block_ptr is None and threshold > 0.0:
        nnz_block_ptr = torch.zeros_like(A, dtype=torch.int32)

    return row_stats, col_stats, nnz_block_ptr


def get_colrow_absmax_old(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
assert A.dtype == torch.float16
device = A.device

Expand Down Expand Up @@ -2543,19 +2574,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


@torch.compile
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
# TODO: Optimize/write CUDA kernel for this. Currently vectorwise_quant will recalculate row/col stats.
# TODO: Optimize/write CUDA kernel for this
# TODO: Support threshold

# if out_col is None:
# out_col = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
# if out_row is None:
# out_row = torch.zeros(A.shape, device=A.device, dtype=torch.int8)
if row_stats is None or col_stats is None:
row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

scaled_A = A.mul(C)

# quant_row = torch.round(A * (C / row_stats.unsqueeze(-1))).to(torch.int8)
# quant_col = torch.round(A * (C / col_stats.unsqueeze(0))).to(torch.int8)
quant_row = torch.round(scaled_A / row_stats.unsqueeze(-1)).to(torch.int8)
quant_col = torch.round(scaled_A / col_stats.unsqueeze(0)).to(torch.int8)

out_col, Scol = vectorwise_quant(A, dim=0)
out_row, Srow = vectorwise_quant(A, dim=1)
if out_row is not None:
quant_row = out_row.copy_(quant_row)
if out_col is not None:
quant_col = out_col.copy_(quant_col)

return out_row, out_col, Srow.flatten().float(), Scol.flatten().float(), None # coo_tensor
return quant_row, quant_col, row_stats.flatten().float(), col_stats.flatten().float(), None


def double_quant_old(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
Expand Down
Loading

0 comments on commit 50fe50e

Please sign in to comment.