
Commit

Enable Intel GPU path and lora finetune and change examples to support different devices (#631)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
jiqing-feng and casper-hansen authored Nov 14, 2024
1 parent b42e3c3 commit 419a242
Showing 8 changed files with 97 additions and 45 deletions.
29 changes: 19 additions & 10 deletions README.md
@@ -47,6 +47,8 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
- Your CUDA version must be CUDA 11.8 or later.
- AMD:
- Your ROCm version must be compatible with Triton.
- Intel CPU and Intel GPU:
- Your torch and intel_extension_for_pytorch versions should be at least 2.4 for optimized performance.

### Install from PyPI

@@ -60,6 +62,10 @@ There are a few ways to install AutoAWQ:
- `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git`
- NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels

3. From the main branch, for optimized performance on Intel CPU and Intel XPU (a quick version check is sketched right after this list):
- `pip install intel_extension_for_pytorch`
- `pip install git+https://github.com/casper-hansen/AutoAWQ.git`
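
For readers on the Intel path, the sketch below is one way to confirm the >= 2.4 requirement noted above. It is a hedged example, not part of AutoAWQ or this commit; `meets_minimum` is a made-up helper.

```python
from importlib.metadata import PackageNotFoundError, version

def meets_minimum(package: str, minimum=(2, 4)) -> bool:
    """Return True if `package` is installed at or above `minimum` (major, minor)."""
    try:
        major, minor = version(package).split(".")[:2]
    except PackageNotFoundError:
        print(f"{package} is not installed")
        return False
    return (int(major), int(minor)) >= minimum

for pkg in ("torch", "intel_extension_for_pytorch"):
    print(pkg, "OK" if meets_minimum(pkg) else "needs >= 2.4")
```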

## Usage

Under examples, you can find examples of how to quantize, run inference, and benchmark AutoAWQ models.
@@ -132,6 +138,9 @@ print(f'Model is quantized and saved at "{quant_path}"')
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
from awq.utils.utils import get_best_device

device = get_best_device()

quant_path = "TheBloke/zephyr-7B-beta-AWQ"

@@ -155,7 +164,7 @@ prompt = "You're standing on the surface of the Earth. "\
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
).input_ids.to(device)

# Generate output
generation_output = model.generate(
@@ -229,18 +238,18 @@ GPU: 2x NVIDIA GeForce RTX 4090
### CPU

- CPU: 48 cores SPR (Intel 4th Gen Xeon CPU)
- Command: `python examples/benchmark.py --model_path <hf_model> --batch_size 1`
- Command: `python examples/benchmark.py --model_path <hf_model> --batch_size 1 --generator hf`

| Model | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory |
|-------|---------|------------|----------------|---------------|-------------------|------------------|---------------|
| Llama 2 7B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) |
| Llama 2 7B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) |
| Falcon | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) |
| Falcon | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) |
| Mistral | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) |
| Mistral | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) |
| Vicuna | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) |
| Vicuna | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) |
| TinyLlama 1B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) |
| TinyLlama 1B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) |
| Falcon 7B | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) |
| Falcon 7B | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) |
| Mistral 7B | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) |
| Mistral 7B | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) |
| Vicuna 7B | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) |
| Vicuna 7B | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) |
| Llama 2 13B | gemm | 1 | 32 | 32 | 220.79 | 18.14 | 17.46 GB (0.02%) |
| Llama 2 13B | gemm | 1 | 2048 | 2048 | 650.94 | 6.54 | 19.84 GB (0.02%) |
| DeepSeek Coder 33B | gemm | 1 | 32 | 32 | 101.61 | 8.58 | 40.80 GB (0.04%) |
15 changes: 7 additions & 8 deletions awq/models/base.py
@@ -447,7 +447,7 @@ def from_quantized(
bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
] = False,
use_ipex: Annotated[
bool, Doc("Whether to map the weights to ipex kernels for CPU device.")
bool, Doc("Whether to map the weights to ipex kernels for CPU and XPU device.")
] = False,
device_map: Annotated[
Union[str, Dict],
@@ -500,8 +500,9 @@ def from_quantized(
trust_remote_code=trust_remote_code,
)

use_cpu_ipex = use_ipex or get_best_device() == "cpu"
if use_cpu_ipex and not ipex_available:
best_device = get_best_device()
use_ipex = use_ipex or best_device in ["cpu", "xpu:0"]
if use_ipex and not ipex_available:
raise ImportError(
"Please install intel_extension_for_pytorch with "
"`pip install intel_extension_for_pytorch` for 'ipex' kernel!"
@@ -514,7 +515,7 @@ quant_config.version,
quant_config.version,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
use_ipex=use_cpu_ipex,
use_ipex=use_ipex,
)

model.tie_weights()
@@ -534,14 +535,12 @@
# Dispatch to devices
awq_ext, msg = try_import("awq_ext")
if fuse_layers:
if awq_ext is None:
if best_device in ["mps", "cuda:0"] and awq_ext is None:
warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg)
else:
self.fuse_layers(model)

if use_cpu_ipex:
dtype = torch.bfloat16
model.to(dtype=dtype, device="cpu")
if use_ipex:
# repack qweight to match the ipex kernel.
model = ipex_post_init(model)
elif quant_config.version == "marlin":
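
The dispatch this hunk introduces can be summarized as a small decision function. The sketch below is a condensed, hypothetical mirror of the logic above, not code from the commit; names such as `resolve_backend` and `awq_ext_loaded` are invented stand-ins (`awq_ext_loaded` corresponds to `awq_ext is not None`).

```python
def resolve_backend(use_ipex: bool, best_device: str,
                    ipex_available: bool, awq_ext_loaded: bool, fuse_layers: bool):
    """Condensed view of the kernel and fusion choices made in from_quantized."""
    # CPU and Intel XPU default to the IPEX weight-only-quantization kernels.
    use_ipex = use_ipex or best_device in ("cpu", "xpu:0")
    if use_ipex and not ipex_available:
        raise ImportError("Please `pip install intel_extension_for_pytorch`.")
    # Fusing is skipped only when a CUDA/MPS device lacks the awq_ext extension;
    # on CPU/XPU the IPEX path handles fused modules without it.
    fuse = fuse_layers and not (best_device in ("mps", "cuda:0") and not awq_ext_loaded)
    return use_ipex, fuse
```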
6 changes: 3 additions & 3 deletions awq/modules/fused/attn.py
@@ -29,7 +29,7 @@ def __init__(self, head_dim, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()

self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
self.precompute_freqs_cis(head_dim, max_seq_len, rope_theta).to(device),
requires_grad=False,
)

@@ -137,8 +137,8 @@ def __init__(
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

if kwargs.get("max_new_tokens") is not None:
max_seq_len = kwargs["max_new_tokens"]
if kwargs.get("max_length") is not None:
max_seq_len = kwargs["max_length"]

self.max_seq_len = max_seq_len
self.is_hf_transformers = False
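
The two changes above shrink the rotary cache from `max_seq_len * 2` to `max_seq_len` entries and size it from `max_length` (prompt plus generated tokens) rather than `max_new_tokens`. For context, the table being precomputed is the standard rotary-embedding cache; the following is a simplified sketch of the usual formulation, assumed rather than copied from the repo's `precompute_freqs_cis`.

```python
import torch

def rope_cache(head_dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """One complex rotation e^{i * position * freq} per (position, dimension pair)."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)             # (max_seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)   # complex64 cache
```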
45 changes: 34 additions & 11 deletions awq/modules/linear/gemm_ipex.py
@@ -1,18 +1,20 @@
import torch
import torch.nn as nn
from .gemm import WQLinear_GEMM
from awq.utils.packing_utils import dequantize_gemm

try:
from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear
assert hasattr(WeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4"
from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear
assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4"
IPEX_INSTALLED = True
except:
IPEX_INSTALLED = False


class WQLinear_IPEX(nn.Module):
class WQLinear_IPEX(WQLinear_GEMM):

def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
nn.Module.__init__(self)
assert IPEX_INSTALLED, \
"Please install IPEX package with `pip install intel_extension_for_pytorch`."
assert w_bit == 4, "Only 4 bit are supported for now."
@@ -24,12 +26,15 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.scale_dtype = torch.float32
self.training = training

# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.pack_num = 32 // self.w_bit

self.init_ipex = False

self.register_buffer(
"qzeros",
torch.zeros(
@@ -59,10 +64,13 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
self.register_buffer("qweight", qweight)

def post_init(self):
assert self.qweight.device.type == "cpu"
self.ipex_linear = WeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \
self.in_features, self.out_features, None, self.bias, \
self.group_size, None, 0, 1)
assert self.qweight.device.type in ("cpu", "xpu")

def init_ipex_linear(self):
if not self.training:
self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \
self.in_features, self.out_features, None, self.bias, \
self.group_size, None, quant_method=1, dtype=0)

@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):
@@ -79,16 +87,31 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None):

raise NotImplementedError("Only inference is supported for IPEX kernels")

@torch.no_grad()
def forward(self, x):
assert IPEX_INSTALLED, (
"IPEX kernels could not be loaded. "
"Please install with `pip install intel_extension_for_pytorch` and "
"refer to the detial https://github.com/intel/intel-extension-for-pytorch/tree/main")

outputs = self.ipex_linear(x)
if not self.init_ipex:
self.init_ipex_linear()
self.init_ipex = True

if hasattr(self, "ipex_linear"):
with torch.no_grad():
outputs = self.ipex_linear(x)
else:
outputs = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(x.dtype)
outputs = torch.matmul(x, outputs)

return outputs

def backward(self, grad_output):
weights = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size).to(grad_output.dtype)
batch_size = grad_output.shape[0]
grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1))

return grad_input, None, None, None, None, None, None, None

def extra_repr(self) -> str:
return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
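
To make the LoRA-finetune path above concrete: when the fused IPEX linear is not initialized (training), the forward falls back to dequantizing the packed int4 weight and running a plain matmul. A hedged sketch of just that fallback follows; the wrapper name is invented, while `dequantize_gemm` is the helper this file already imports.

```python
import torch
from awq.utils.packing_utils import dequantize_gemm  # imported by gemm_ipex.py above

def ipex_linear_fallback(x, qweight, qzeros, scales, w_bit, group_size):
    """Dense fallback used instead of the fused IPEX kernel (e.g. while finetuning)."""
    # Recover a dense (in_features, out_features) weight from the packed int4 buffers.
    weight = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size).to(x.dtype)
    return torch.matmul(x, weight)
```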
6 changes: 3 additions & 3 deletions awq/utils/fused_utils.py
@@ -67,7 +67,9 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
else None
)

if isinstance(q_proj, WQLinear_GEMV):
if isinstance(q_proj, WQLinear_IPEX):
q_linear = WQLinear_IPEX
elif isinstance(q_proj, WQLinear_GEMV):
q_linear = WQLinear_GEMV
elif isinstance(q_proj, WQLinear_GEMM):
q_linear = WQLinear_GEMM
@@ -79,8 +81,6 @@ q_linear = WQLinear_Marlin
q_linear = WQLinear_Marlin
elif isinstance(q_proj, WQLinear_GEMVFast):
q_linear = WQLinear_GEMVFast
elif isinstance(q_proj, WQLinear_IPEX):
q_linear = WQLinear_IPEX

qkv_layer = q_linear(
q_proj.w_bit,
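
The reordering above matters because `WQLinear_IPEX` now subclasses `WQLinear_GEMM` (see `gemm_ipex.py`): tested after the base class, IPEX layers would match the GEMM branch first. A toy illustration with stand-in classes, not the real imports:

```python
class WQLinear_GEMM: ...
class WQLinear_IPEX(WQLinear_GEMM): ...  # mirrors the new inheritance

def pick_linear_cls(q_proj):
    if isinstance(q_proj, WQLinear_IPEX):    # must precede the base-class check
        return WQLinear_IPEX
    elif isinstance(q_proj, WQLinear_GEMM):
        return WQLinear_GEMM
    raise TypeError(type(q_proj))

assert pick_linear_cls(WQLinear_IPEX()) is WQLinear_IPEX
```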
2 changes: 2 additions & 0 deletions awq/utils/utils.py
@@ -91,6 +91,8 @@ def get_best_device():
return "mps"
elif torch.cuda.is_available():
return "cuda:0"
elif torch.xpu.is_available():
return "xpu:0"
else:
return "cpu"

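
Downstream code can stay device-agnostic by routing everything through this helper; a small usage sketch, assuming a build with this commit installed:

```python
import torch
from awq.utils.utils import get_best_device

device = get_best_device()        # "mps", else "cuda:0", else "xpu:0", else "cpu"
x = torch.randn(4, 4).to(device)  # tensors and models move the same way
print(device, x.device)
```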
35 changes: 26 additions & 9 deletions examples/benchmark.py
@@ -11,9 +11,9 @@
from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList

DEVICE = get_best_device()
if DEVICE == "cpu":
if DEVICE in ["cpu", "xpu:0"]:
if ipex_available:
torch_dtype = torch.bfloat16
torch_dtype = torch.bfloat16 if DEVICE == "cpu" else torch.float16
else:
raise ImportError("Please import intel_extension_for_pytorch "
"by `pip install intel_extension_for_pytorch`")
@@ -29,8 +29,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
"""The logit processor is called after the model forward."""

# CUDA runs operations asynchronously, so we synchronize for accurate time measurement
if DEVICE != "cpu":
if DEVICE == "cuda:0":
torch.cuda.synchronize()
elif DEVICE == "xpu:0":
torch.xpu.synchronize()

# measure time
start_time = time.time()
@@ -56,8 +58,10 @@ def generate_torch(model, input_ids, n_generate):

with torch.inference_mode():
for i in range(n_generate):
if DEVICE != "cpu":
if DEVICE == "cuda:0":
torch.cuda.synchronize()
elif DEVICE == "xpu:0":
torch.xpu.synchronize()
start = time.time()

if i == 0:
@@ -69,8 +73,10 @@

out = model(inputs, use_cache=True)

if DEVICE != "cpu":
if DEVICE == "cuda:0":
torch.cuda.synchronize()
elif DEVICE == "xpu:0":
torch.xpu.synchronize()
token = out[0][:, -1].max(1)[1].unsqueeze(1)

if i == 0:
@@ -102,7 +108,7 @@ def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate):

return context_time, generate_time

def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors, pretrained):
def run_round(generator, model_path, quant_file, n_generate, context, input_ids, batch_size, no_safetensors, pretrained):
print(f" -- Loading model...")

if pretrained:
Expand All @@ -114,7 +120,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
)
else:
model = AutoAWQForCausalLM.from_quantized(
model_path, quant_file, max_seq_len=n_generate, batch_size=batch_size, safetensors=not no_safetensors
model_path, quant_file, max_seq_len=n_generate+context, batch_size=batch_size, safetensors=not no_safetensors
)

print(f" -- Warming up...")
@@ -149,6 +155,12 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
memory_pct = mem_info.rss / memory_info.total
total_memory_used = float(mem_info.rss) / (1024 ** 3)
print(f" ** Max Memory (device: {DEVICE}): {total_memory_used:.2f} GB ({memory_pct:.2f}%)")
elif DEVICE == "xpu:0":
for device in range(torch.xpu.device_count()):
memory_used = torch.xpu.max_memory_allocated(device) / (1024 ** 3)
total_memory_used += memory_used
memory_pct = memory_used / (torch.xpu.get_device_properties(device).total_memory / (1024 ** 3)) * 100
print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)")
else:
for device in range(torch.cuda.device_count()):
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
@@ -197,14 +209,17 @@ def main(args):

for settings in rounds:
input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"]))
if DEVICE != "cpu":
if DEVICE == "cuda:0":
input_ids = input_ids.cuda()
elif DEVICE == "xpu:0":
input_ids = input_ids.to("xpu:0")

stats, model_version = run_round(
generator,
args.model_path,
args.quant_file,
settings["n_generate"],
settings["context"],
input_ids,
args.batch_size,
args.no_safetensors,
@@ -218,8 +233,10 @@

df = pd.DataFrame(all_stats)
print('Device:', DEVICE)
if DEVICE != "cpu":
if DEVICE == "cuda:0":
print('GPU:', torch.cuda.get_device_name())
elif DEVICE == "xpu:0":
print('XPU:', torch.xpu.get_device_name())
print('Model:', args.model_path)
print('Version:', model_version)
print(df.to_markdown(index=False))
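
The same CUDA/XPU synchronization branch now appears in the logits processor and in `generate_torch`; a small helper like the one below (hypothetical, not part of the commit) captures the pattern:

```python
import torch

def sync(device: str) -> None:
    """Block until queued kernels finish so the timers measure completed work."""
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()
    # CPU execution in the benchmark is synchronous, so no call is needed there.
```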
4 changes: 3 additions & 1 deletion examples/generate.py
@@ -1,7 +1,9 @@
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
from awq.utils.utils import get_best_device

device = get_best_device()
model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -26,7 +28,7 @@
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to("cuda")
).to(device)

outputs = model.generate(
**inputs,
