Skip to content

Commit

Permalink
Replace itrex qbits to ipex woq linear (#549)
Browse files Browse the repository at this point in the history
Co-authored-by: Casper <casperbh.96@gmail.com>
  • Loading branch information
jiqing-feng and casper-hansen authored Sep 12, 2024
1 parent d5ec43d commit eab1a4a
Show file tree
Hide file tree
Showing 17 changed files with 279 additions and 326 deletions.
35 changes: 17 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,26 +228,25 @@ GPU: 2x NVIDIA GeForce RTX 4090

### CPU

- CPU: INTEL(R) XEON(R) PLATINUM 8592+ with 8-channel 4800MT/s memory.
- CPU: 48 cores SPR (Intel 4th Gen Xeon CPU)
- Command: `python examples/benchmark.py --model_path <hf_model> --batch_size 1`

| Model | Size | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (RAM) |
|--------:|------:|-----------:|-------------:|-----------------:|----------------:|---------------:|:------------------|
| Mixtral | 7B | 1 | 64 | 64 | 389.24 | 16.01 | 5.59 GB (0.02%) |
| Mixtral | 7B | 1 | 2048 | 2048 | 1412 | 17.76 | 6.29 GB (0.03%) |
| Vicuna | 7B | 1 | 64 | 64 | 346 | 18.13 | 8.18 GB (0.03%) |
| Vicuna | 7B | 1 | 2048 | 2048 | 1023.4 | 18.18 | 8.80 GB (0.04%) |
| LLaMA2 | 13B | 1 | 64 | 64 | 160.24 | 9.87 | 14.65 GB (0.06%) |
| LLaMA2 | 13B | 1 | 2048 | 2048 | 592.35 | 9.93 | 16.87 GB (0.07%) |
| Mosaicml | 7B | 1 | 64 | 64 | 433.17 | 18.79 | 4.60 GB (0.02%) |
| Mosaicml | 7B | 1 | 2048 | 2048 | 404.25 | 19.91 | 4.75 GB (0.02%) |
| Falcon | 7B | 1 | 64 | 64 | 303.16 | 14.41 | 5.18 GB (0.02%) |
| Falcon | 7B | 1 | 2048 | 2048 | 634.57 | 15.55 | 5.80 GB (0.02%) |
| CodeLlama | 34B | 1 | 64 | 64 | 153.73 | 4.23 | 29.00 GB (0.12%) |
| CodeLlama | 34B | 1 | 2048 | 2048 | 274.25 | 4.38 | 35.21 GB (0.15%) |
| Deepseek-coder | 33B | 1 | 64 | 64 | 83.08 | 4.07 | 22.16 GB (0.09%) |
| Deepseek-coder | 33B | 1 | 2048 | 2048 | 296.04 | 4.33 | 37.05 GB |

| Model | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory |
|-------|---------|------------|----------------|---------------|-------------------|------------------|---------------|
| Llama 2 7B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) |
| Llama 2 7B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) |
| Falcon | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) |
| Falcon | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) |
| Mistral | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) |
| Mistral | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) |
| Vicuna | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) |
| Vicuna | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) |
| Llama 2 13B | gemm | 1 | 32 | 32 | 220.79 | 18.14 | 17.46 GB (0.02%) |
| Llama 2 13B | gemm | 1 | 2048 | 2048 | 650.94 | 6.54 | 19.84 GB (0.02%) |
| DeepSeek Coder 33B | gemm | 1 | 32 | 32 | 101.61 | 8.58 | 40.80 GB (0.04%) |
| DeepSeek Coder 33B | gemm | 1 | 2048 | 2048 | 245.02 | 3.48 | 41.72 GB (0.04%) |
| Phind CodeLlama 34B | gemm | 1 | 32 | 32 | 102.47 | 9.04 | 41.70 GB (0.04%) |
| Phind CodeLlama 34B | gemm | 1 | 2048 | 2048 | 237.57 | 3.48 | 42.47 GB (0.04%) |

## Reference

Expand Down
4 changes: 2 additions & 2 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def from_quantized(
fuse_layers=True,
use_exllama=False,
use_exllama_v2=False,
use_qbits=False,
use_ipex=False,
batch_size=1,
safetensors=True,
device_map="balanced",
Expand Down Expand Up @@ -116,7 +116,7 @@ def from_quantized(
fuse_layers=fuse_layers,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
use_qbits=use_qbits,
use_ipex=use_ipex,
safetensors=safetensors,
device_map=device_map,
max_memory=max_memory,
Expand Down
51 changes: 21 additions & 30 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_QBits,
WQLinear_IPEX,
WQLinear_Marlin,
WQLinear_Exllama,
WQLinear_ExllamaV2,
WQLinear_GEMVFast,
marlin_post_init,
exllama_post_init,
exllamav2_post_init,
qbits_post_init,
ipex_post_init,
)
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
try_import,
)
from awq.utils.utils import get_best_device, qbits_available
from awq.utils.utils import get_best_device, ipex_available
from transformers import (
AutoConfig,
PreTrainedModel,
Expand All @@ -52,9 +52,6 @@
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name

if qbits_available:
from intel_extension_for_transformers.qbits import check_isa_supported


# Since we support different `AutoModelForxxx` from transformers
# we need to define a custom mapping dict as below:
Expand Down Expand Up @@ -440,8 +437,8 @@ def from_quantized(
use_exllama_v2: Annotated[
bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
] = False,
use_qbits: Annotated[
bool, Doc("Whether to map the weights to qbits kernels for CPU device.")
use_ipex: Annotated[
bool, Doc("Whether to map the weights to ipex kernels for CPU device.")
] = False,
device_map: Annotated[
Union[str, Dict],
Expand Down Expand Up @@ -494,17 +491,11 @@ def from_quantized(
trust_remote_code=trust_remote_code,
)

use_cpu_qbits = use_qbits or get_best_device() == "cpu"
if use_cpu_qbits:
if not qbits_available:
raise ImportError(
"Please install intel-extension-for-transformers with "
"`pip install intel-extension-for-transformers` for 'qbits' kernel!"
)

fuse_layers = False
logging.warn(
"Unsupport fuse_layers featrue for CPU device with QBits backend!"
use_cpu_ipex = use_ipex or get_best_device() == "cpu"
if use_cpu_ipex and not ipex_available:
raise ImportError(
"Please install intel_extension_for_pytorch with "
"`pip install intel_extension_for_pytorch` for 'ipex' kernel!"
)
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(
Expand All @@ -514,7 +505,7 @@ def from_quantized(
quant_config.version,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
use_qbits=use_cpu_qbits,
use_ipex=use_cpu_ipex,
)

model.tie_weights()
Expand All @@ -539,11 +530,11 @@ def from_quantized(
else:
self.fuse_layers(model)

if use_cpu_qbits:
dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
if use_cpu_ipex:
dtype = torch.bfloat16
model.to(dtype=dtype, device="cpu")
# repack qweight to match the QBits kernel.
model = qbits_post_init(model)
# repack qweight to match the ipex kernel.
model = ipex_post_init(model)
elif quant_config.version == "marlin":
model = marlin_post_init(model)
elif use_exllama:
Expand Down Expand Up @@ -631,11 +622,11 @@ def _load_config(
return model_weights_path, config, quant_config

def _load_quantized_modules(
self, model, quant_config, version, use_exllama, use_exllama_v2, use_qbits=False
self, model, quant_config, version, use_exllama, use_exllama_v2, use_ipex=False
):
# Real quantization of weights
assert not (
version == "gemv" and (use_exllama or use_exllama_v2 or use_qbits)
version == "gemv" and (use_exllama or use_exllama_v2 or use_ipex)
), "Exllama kernels only support GEMM version."

# Get blocks of model
Expand All @@ -657,8 +648,8 @@ def _load_quantized_modules(

# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if use_qbits:
q_linear_module = WQLinear_QBits
if use_ipex:
q_linear_module = WQLinear_IPEX
elif version == "marlin":
q_linear_module = WQLinear_Marlin
elif use_exllama:
Expand All @@ -672,7 +663,7 @@ def _load_quantized_modules(
elif version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

if use_qbits:
if use_ipex:
q_linear = q_linear_module.from_linear(
module,
quant_config.w_bit,
Expand All @@ -687,7 +678,7 @@ def _load_quantized_modules(
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)

if not use_qbits:
if not use_ipex:
torch.cuda.empty_cache()
gc.collect()

Expand Down
48 changes: 32 additions & 16 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __init__(
self.is_neox = kwargs["is_neox"]

self.attn_logit_softcapping = attn_logit_softcapping
self.use_sdpa = kwargs.get("use_sdpa", False)

def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
Expand Down Expand Up @@ -266,29 +267,44 @@ def forward(
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

# Used in Gemma2
if self.attn_logit_softcapping is not None:
scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping

if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)

# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
# For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
# need to slice it
if attention_mask.shape[-1] != seqlen:
attention_mask = attention_mask[:, :, :seqlen, :seqlen]

scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
if self.use_sdpa:
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : keys.shape[-2]]
is_causal = True if causal_mask is None and seqlen > 1 else False
output = torch.nn.functional.scaled_dot_product_attention(
xq,
keys,
values,
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=is_causal,
)
else:
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)

# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
# For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
# need to slice it
if attention_mask.shape[-1] != seqlen:
attention_mask = attention_mask[:, :, :seqlen, :seqlen]

scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)

attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
else:
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
Expand Down
1 change: 1 addition & 0 deletions awq/modules/fused/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
head_dim=head_dim,
use_sdpa=True,
).to(dev)
self.norm_2 = norm_2.to(dev)
self.mlp = mlp.to(dev)
Expand Down
23 changes: 16 additions & 7 deletions awq/modules/fused/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
except:
AWQ_INSTALLED = False

try:
import intel_extension_for_pytorch as ipex # with IPEX kernels

IPEX_INSTALLED = True
except:
IPEX_INSTALLED = False


class FasterTransformerRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
Expand All @@ -16,12 +23,14 @@ def __init__(self, weight, eps=1e-6):
self.variance_epsilon = eps

def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)

output = torch.empty_like(x)
awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
if IPEX_INSTALLED:
output = ipex.llm.functional.rms_norm(x, self.weight, self.variance_epsilon)
else:
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
output = torch.empty_like(x)
awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)

return output
2 changes: 1 addition & 1 deletion awq/modules/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .exllama import WQLinear_Exllama, exllama_post_init
from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from .gemm import WQLinear_GEMM
from .gemm_qbits import WQLinear_QBits, qbits_post_init
from .gemm_ipex import WQLinear_IPEX, ipex_post_init
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin, marlin_post_init
from .gemv_fast import WQLinear_GEMVFast
Loading

0 comments on commit eab1a4a

Please sign in to comment.