diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py
index aff1e279..1eb942c2 100644
--- a/awq/modules/linear/gemm.py
+++ b/awq/modules/linear/gemm.py
@@ -63,9 +63,16 @@ def forward(
     def backward(ctx, grad_output):
         input, qweight, qzeros, scales, bias = ctx.saved_tensors
 
+        if awq_ext is None:
+            raise ValueError(
+                "auto-awq kernels are needed to use `.backward()`. Make sure to install the auto-awq kernels"
+                " by following the installation guide at https://github.com/casper-hansen/AutoAWQ_kernels"
+            )
+
+        # Cast to correct dtype for mixed precision training
         weights = awq_ext.dequantize_weights_cuda(
             qweight, scales, qzeros, 1, 0, 0, False
-        )
+        ).to(grad_output.dtype)
 
         if ctx.needs_input_grad[0]:
             # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
@@ -75,7 +82,6 @@ def backward(ctx, grad_output):
 
         return grad_input, None, None, None, None, None, None, None
 
-
 class WQLinear_GEMM(nn.Module):
     def __init__(
         self, w_bit, group_size, in_features, out_features, bias, dev, training=False
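
Why the `.to(grad_output.dtype)` cast matters: `awq_ext.dequantize_weights_cuda` returns fp32 weights, while under mixed-precision training `grad_output` typically arrives as fp16/bf16, and `torch.bmm` rejects mismatched dtypes. The sketch below reproduces the pattern in plain PyTorch; `FakeQuantLinearFn` is a hypothetical stand-in for the real autograd function, a plain fp32 tensor mocks the kernel's dequantized output, and the exact `bmm` expression for `grad_input` is an assumption, since that line sits outside the hunk shown above.

import torch


class FakeQuantLinearFn(torch.autograd.Function):
    """Hypothetical stand-in for the AWQ linear autograd function (illustration only)."""

    @staticmethod
    def forward(ctx, x, weights_fp32):
        # The real kernel keeps the weights quantized; here a plain fp32
        # tensor plays the role of the dequantized weights.
        ctx.save_for_backward(x, weights_fp32)
        return torch.matmul(x, weights_fp32.to(x.dtype))

    @staticmethod
    def backward(ctx, grad_output):
        x, weights_fp32 = ctx.saved_tensors
        # Without this cast, bmm fails under mixed precision because
        # grad_output is fp16/bf16 while the dequantized weights are fp32.
        weights = weights_fp32.to(grad_output.dtype)
        grad_input = None
        if ctx.needs_input_grad[0]:
            batch_size = grad_output.shape[0]
            # (B, S, out) x (B, out, in) -> (B, S, in); assumed to mirror
            # the torch.bmm computation referenced in the file's comment.
            grad_input = grad_output.bmm(
                weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)
            )
        return grad_input, None


x = torch.randn(2, 8, 64, dtype=torch.bfloat16, requires_grad=True)
w_fp32 = torch.randn(64, 128)  # mock "dequantized" weights in fp32
out = FakeQuantLinearFn.apply(x, w_fp32)
out.sum().backward()
print(x.grad.dtype)  # torch.bfloat16, matching grad_output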