From c025b157d3cd112b97460178d4d9340fff826277 Mon Sep 17 00:00:00 2001 From: Casper Date: Tue, 2 Jul 2024 14:57:35 +0200 Subject: [PATCH] Batched quantization (#516) --- awq/models/base.py | 79 ++++++++++++--- awq/quantize/quantizer.py | 203 ++++++++++++++++++++++++++++---------- awq/utils/calib_data.py | 12 +-- docs/examples.md | 103 +++++++++++++++++++ 4 files changed, 321 insertions(+), 76 deletions(-) diff --git a/awq/models/base.py b/awq/models/base.py index 0f641160..66dc02e6 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -81,7 +81,7 @@ "phi3": "AutoModelForCausalLM", "cohere": "AutoModelForCausalLM", "deepseek_v2": "AutoModelForCausalLM", - "minicpm":"AutoModelForCausalLM", + "minicpm": "AutoModelForCausalLM", } @@ -156,6 +156,34 @@ def quantize( "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False." ), ] = True, + n_parallel_calib_samples: Annotated[ + int, + Doc( + "The number of parallel samples to run through the model. " + "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. " + "If None, runs through all samples at the same time. " + "You can set this to a low number for more memory efficient quantization." + ), + ] = None, + max_calib_samples: Annotated[ + int, Doc("The maximum number of samples to run through the model.") + ] = 128, + max_calib_seq_len: Annotated[ + int, + Doc( + "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len." + ), + ] = 512, + max_chunk_memory: Annotated[ + int, + Doc( + "The loss computation and per-channel mean is optimized into chunked computations." + " Adjust this parameter to increase or decrease memory usage for these computations." + " Default is 1GB (1024 * 1024 * 1024)." + ), + ] = 1024 + * 1024 + * 1024, ): """ The main quantization function that you can use to quantize your model. @@ -194,6 +222,10 @@ def quantize( modules_to_not_convert=self.quant_config.modules_to_not_convert, export_compatible=export_compatible, apply_clip=apply_clip, + n_parallel_calib_samples=n_parallel_calib_samples, + max_calib_samples=max_calib_samples, + max_calib_seq_len=max_calib_seq_len, + max_chunk_memory=max_chunk_memory, ) self.quantizer.quantize() @@ -312,7 +344,8 @@ def from_pretrained( ), ] = None, download_kwargs: Annotated[ - Dict, Doc("Used for configure download model"), + Dict, + Doc("Used for configure download model"), ] = None, **model_init_kwargs: Annotated[ Dict, @@ -324,9 +357,12 @@ def from_pretrained( """A method for initialization of pretrained models, usually in FP16.""" # Get weights path and quant config model_weights_path, config, quant_config = self._load_config( - self, model_path, "", safetensors, + self, + model_path, + "", + safetensors, trust_remote_code=trust_remote_code, - download_kwargs=download_kwargs + download_kwargs=download_kwargs, ) target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type] @@ -409,7 +445,7 @@ def from_quantized( ), ] = "balanced", max_memory: Annotated[ - Dict[Union[int, str], Union[int, str]], + Dict[Union[int, str], Union[int, str]], Doc( 'A dictionary device identifier to maximum memory which will be passed onto the model loading method from transformers. For example:{0: "4GB",1: "10GB"' ), @@ -419,7 +455,8 @@ def from_quantized( Doc("The folder ot offload the model to."), ] = None, download_kwargs: Annotated[ - Dict, Doc("Used for configure download model"), + Dict, + Doc("Used for configure download model"), ] = None, **config_kwargs: Annotated[ Dict, @@ -455,11 +492,15 @@ def from_quantized( use_cpu_qbits = use_qbits or get_best_device() == "cpu" if use_cpu_qbits: if not qbits_available: - raise ImportError("Please install intel-extension-for-transformers with " - "`pip install intel-extension-for-transformers` for 'qbits' kernel!") + raise ImportError( + "Please install intel-extension-for-transformers with " + "`pip install intel-extension-for-transformers` for 'qbits' kernel!" + ) fuse_layers = False - logging.warn("Unsupport fuse_layers featrue for CPU device with QBits backend!") + logging.warn( + "Unsupport fuse_layers featrue for CPU device with QBits backend!" + ) # Prepare WQLinear layers, replace nn.Linear self._load_quantized_modules( self, @@ -547,7 +588,9 @@ def _load_config( elif isinstance(download_kwargs_ignore_patterns, list): ignore_patterns.extend(download_kwargs_ignore_patterns) - model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns, **download_kwargs) + model_path = snapshot_download( + model_path, ignore_patterns=ignore_patterns, **download_kwargs + ) if model_filename != "": model_weights_path = model_path + f"/{model_filename}" @@ -621,13 +664,17 @@ def _load_quantized_modules( q_linear_module = WQLinear_GEMVFast if use_qbits: - q_linear = q_linear_module.from_linear(module, - quant_config.w_bit, - quant_config.q_group_size, - True, - has_zero_points=quant_config.zero_point) + q_linear = q_linear_module.from_linear( + module, + quant_config.w_bit, + quant_config.q_group_size, + True, + has_zero_points=quant_config.zero_point, + ) else: - q_linear = q_linear_module.from_linear(module, quant_config.w_bit, quant_config.q_group_size, True) + q_linear = q_linear_module.from_linear( + module, quant_config.w_bit, quant_config.q_group_size, True + ) q_linear.to(next(layer.parameters()).device) set_op_by_name(layer, name, q_linear) diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index f2e36dc8..0118eac2 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -41,6 +41,10 @@ def __init__( modules_to_not_convert=None, export_compatible=False, apply_clip=True, + n_parallel_calib_samples=None, + max_calib_samples=128, + max_calib_seq_len=512, + max_chunk_memory=1024 * 1024 * 1024, ) -> None: self.awq_model = awq_model self.model = model @@ -55,10 +59,16 @@ def __init__( self.duo_scaling = duo_scaling self.export_compatible = export_compatible self.apply_clip = apply_clip + self.n_parallel_calib_samples = n_parallel_calib_samples + self.max_calib_samples = max_calib_samples + self.max_calib_seq_len = max_calib_seq_len + self.max_chunk_memory = max_chunk_memory self.modules_to_not_convert = ( modules_to_not_convert if modules_to_not_convert is not None else [] ) - self.modules, self.module_kwargs, self.inps = self.init_quant() + self.modules, self.module_kwargs, self.inps = self.init_quant( + n_samples=self.max_calib_samples, max_seq_len=self.max_calib_seq_len + ) def pseudo_quantize_tensor(self, w: torch.Tensor): org_w_shape = w.shape @@ -155,7 +165,7 @@ def quantize(self): ) scales_list = [ self._search_best_scale(self.modules[i], **layer) - for layer in module_config + for layer in tqdm(module_config, desc="Best Scales", leave=False) ] apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat) scales_list = append_str_prefix( @@ -207,7 +217,7 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): elif self.version == "marlin": q_linear_module = WQLinear_Marlin - + elif self.version == "gemv_fast": q_linear_module = WQLinear_GEMVFast @@ -228,6 +238,34 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): set_op_by_name(module, name, q_linear) clear_memory() + @torch.no_grad() + def _module_forward( + self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict + ) -> torch.Tensor: + if self.n_parallel_calib_samples is None: + # runs through all samples at once + module_output = module(x, **module_kwargs) + if isinstance(module_output, tuple): + module_output = module_output[0] + else: + # memory efficiently runs through all calibration samples + # but only n_parallel_calib_samples at a time + module_output = [] + partitioned_inputs = torch.split(x, self.n_parallel_calib_samples) + for x_partial in tqdm( + partitioned_inputs, desc="Module forward", leave=False + ): + partial_output = module(x_partial, **module_kwargs) + + if isinstance(partial_output, tuple): + partial_output = partial_output[0] + + module_output.append(partial_output.cpu()) + + module_output = torch.cat(module_output, dim=0) + + return module_output + @torch.no_grad() def _search_best_scale( self, @@ -254,7 +292,7 @@ def _search_best_scale( org_shape = weight.shape # The weights are reshaped to be organised by quantization group weight = weight.view(-1, self.group_size) - # Calculates the relative magnitude of the weights within each of the quantization groups, + # Calculates the relative magnitude of the weights within each of the quantization groups, # and rescales each group individually so that each group has weights on a 0-1 scale. w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6) # Resizes the rescaled weight matrix back up to its original dimensions @@ -263,16 +301,32 @@ def _search_best_scale( w_mean = w_scale.mean(0) clear_memory(weight) - # [STEP 2]: Compute per-channel mean of the input activation - x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0) + # [STEP 2]: Compute per-channel mean of the input activation with chunking + # move inp to cpu to avoid memory leak + inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) + num_elements = inp_flat.size(0) + num_channels = inp_flat.size(1) + element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32 + + # Calculate chunk size dynamically based on max_chunk_memory + chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels)) + chunk_size = min(chunk_size, num_elements) + + # Use float32 for sum calculation + x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device) + + for i in range(0, num_elements, chunk_size): + end = min(i + chunk_size, num_elements) + chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0) + x_sum += chunk_sum.to(inp.device) + + x_mean = (x_sum / num_elements).to(inp.dtype) + clear_memory(x_sum) # [STEP 3]: Compute output of module with torch.no_grad(): module_kwargs = self._sanitize_kwargs(kwargs, module2inspect) - - fp16_output = module2inspect(inp, **module_kwargs) - if isinstance(fp16_output, tuple): - fp16_output = fp16_output[0] + fp16_output = self._module_forward(inp, module2inspect, module_kwargs) # [STEP 4]: Compute loss best_scales = self._compute_best_scale( @@ -287,13 +341,13 @@ def _search_best_scale( def _compute_best_scale( self, - x, - w_mean, - x_mean, - module2inspect, + x: torch.Tensor, + w_mean: torch.Tensor, + x_mean: torch.Tensor, + module2inspect: torch.nn.Module, linears2scale: List[nn.Linear], - fp16_output, - kwargs={}, + fp16_output: torch.Tensor, + kwargs: Dict={}, ): """ Compute loss and select best scales @@ -316,41 +370,43 @@ def _compute_best_scale( x_mean = x_mean.view(-1).to(device) w_mean = w_mean.view(-1).to(device) - for ratio in range(n_grid): - # create new scales - ratio = ratio / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if self.duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - scales_view = scales.view(1, -1).to(device) - - # Q(W * s) - for fc in linears2scale: - fc.weight.mul_(scales_view) - fc.weight.data = ( - self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view - ) - - # W * X - int_w_output = module2inspect(x, **kwargs) - if isinstance(int_w_output, tuple): - int_w_output = int_w_output[0] - - # compute mean squared error (L2 norm) - loss = ( - (fp16_output - int_w_output).float().pow(2).mean().item() - ) # NOTE: float prevents overflow + with tqdm(range(n_grid), desc="Grid Search", leave=False) as pbar: + for ratio in pbar: + # create new scales + ratio = ratio / n_grid - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() - module2inspect.load_state_dict(org_sd) + # NOTE: s^-1 * x is fused here, according to paper + if self.duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + scales_view = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for fc in linears2scale: + fc.weight.mul_(scales_view) + fc.weight.data = ( + self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view + ) + + # W * X + int_w_output = self._module_forward(x, module2inspect, kwargs) + + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_output, int_w_output, device) + + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() + module2inspect.load_state_dict(org_sd) + pbar.set_description(f"Grid Search (Best: {best_ratio})") if best_ratio == -1: logging.debug(history) @@ -360,12 +416,51 @@ def _compute_best_scale( return best_scales.detach().cpu() + @torch.no_grad() + def _compute_loss( + self, + fp16_output: torch.Tensor, + int_w_output: torch.Tensor, + device: torch.device, + ): + loss = 0.0 + fp16_output_flat = fp16_output.view(-1) + int_w_output_flat = int_w_output.view(-1) + num_elements = fp16_output_flat.size(0) + element_size_bytes = fp16_output.element_size() + + # Calculate chunk size dynamically based on max_chunk_memory + # Divide the max_chunk_memory by twice the element size + chunk_size = self.max_chunk_memory // (element_size_bytes * 2) + chunk_size = min(chunk_size, num_elements) + + # Split the computation into chunks + fp16_chunks = torch.split(fp16_output_flat, chunk_size) + int_w_chunks = torch.split(int_w_output_flat, chunk_size) + + # Compute the loss for each chunk + with tqdm( + zip(fp16_chunks, int_w_chunks), + total=len(fp16_chunks), + desc="Computing Loss", + leave=False, + ) as pbar: + for fp16_chunk, int_w_chunk in pbar: + chunk_loss = (fp16_chunk.to(device) - int_w_chunk.to(device)).float().pow(2).sum().item() + loss += chunk_loss + pbar.set_description(f"Computing Loss (loss: {loss:.2f})") + + # Normalize the loss by the total number of elements + loss /= num_elements + + return loss + @torch.no_grad() def _search_best_clip(self, layer, named_linears, input_feat): clip_list = [] avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] - for name in named_linears: + for name in tqdm(named_linears, desc="Computing Best Clip", leave=False): # due to qk bmm, it is hard to clip precisely if any([_ in name for _ in avoid_clipping]): continue @@ -436,13 +531,13 @@ def _compute_best_clip( return best_max_val.squeeze(1) - def init_quant(self, n_samples=128, seqlen=512): + def init_quant(self, n_samples=128, max_seq_len=512): modules = self.awq_model.get_model_layers(self.model) samples = get_calib_dataset( data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples, - block_size=seqlen, + max_seq_len=max_seq_len, split=self.split, text_column=self.text_column, ) @@ -542,7 +637,7 @@ def cache_input_hook(m, x, y, name, feat_dict): # Useful for trust_remote_code models. module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer) - self.inps = layer(self.inps, **module_kwargs)[0] + self.inps = self._module_forward(self.inps, layer, module_kwargs) for h in handles: h.remove() # now solve for scaling and clipping diff --git a/awq/utils/calib_data.py b/awq/utils/calib_data.py index 2408cf3f..25a523ea 100644 --- a/awq/utils/calib_data.py +++ b/awq/utils/calib_data.py @@ -7,8 +7,8 @@ def get_calib_dataset( data: Union[str, List[str], List[List[int]]] = "pileval", tokenizer=None, - n_samples=512, - block_size=512, + n_samples=128, + max_seq_len=512, split="train", text_column="text", ): @@ -47,7 +47,7 @@ def get_calib_dataset( line = data[text_column] line = line.strip() line_encoded = tokenizer.encode(line) - if len(line_encoded) > 512: + if len(line_encoded) > max_seq_len: continue sample = torch.tensor([line_encoded]) if sample.numel() == 0: @@ -56,10 +56,10 @@ def get_calib_dataset( n_run += 1 if n_run == n_samples: break - # now concatenate all samples and split according to block size + # now concatenate all samples and split according to max sequence length cat_samples = torch.cat(samples, dim=1) - n_split = cat_samples.shape[1] // block_size + n_split = cat_samples.shape[1] // max_seq_len logging.debug(f" * Split into {n_split} blocks") return [ - cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split) + cat_samples[:, i * max_seq_len : (i + 1) * max_seq_len] for i in range(n_split) ] diff --git a/docs/examples.md b/docs/examples.md index 5e3cd580..de5ac4c3 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -78,6 +78,109 @@ tokenizer.save_pretrained(quant_path) print(f'Model is quantized and saved at "{quant_path}"') ``` +#### Long-context: Optimizing quantization + +For this example, we will use HuggingFaceTB/cosmopedia-100k as it's a high-quality dataset and +we can filter directly on the number of tokens. We will use Qwen2 7B, one of the newer supported +models in AutoAWQ which is high-performing. The following example ran smoothly on a machine with +an RTX 4090 24 GB VRAM with 107 GB system RAM. + +NOTE: Adjusting `n_parallel_calib_samples`, `max_calib_samples`, and `max_calib_seq_len` will help +avoid OOM when customizing your dataset. + +- The AWQ algorithm is incredibly sample efficient, so `max_calib_samples` of 128-256 should be +sufficient to quantize a model. A higher number of samples may not be possible without significant +memory available or without further optimizing AWQ with a PR for disk offload. +- When `n_parallel_calib_samples` is set to an integer, we offload to system RAM to save GPU VRAM. +This may cause OOM on your system if you have little memory available; we are looking to optimize +this further in future versions. + +```python +from datasets import load_dataset +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'Qwen/Qwen2-7B-Instruct' +quant_path = 'qwen2-7b-awq' +quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained( + model_path, **{"low_cpu_mem_usage": True, "use_cache": False} +) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +def load_cosmopedia(): + data = load_dataset('HuggingFaceTB/cosmopedia-100k', split="train") + data = data.filter(lambda x: x["text_token_length"] >= 2048) + + return [text for text in data["text"]] + +# Quantize +model.quantize( + tokenizer, + quant_config=quant_config, + calib_data=load_cosmopedia(), + n_parallel_calib_samples=32, + max_calib_samples=128, + max_calib_seq_len=4096 +) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') +``` + +#### Coding models + +For this example, we will use deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct as it's an excellent coding model. + +```python +from tqdm import tqdm +from datasets import load_dataset +from awq import AutoAWQForCausalLM +from transformers import AutoTokenizer + +model_path = 'deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct' +quant_path = 'deepseek-coder-v2-lite-instruct-awq' +quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" } + +# Load model +model = AutoAWQForCausalLM.from_pretrained( + model_path, **{"low_cpu_mem_usage": True, "use_cache": False} +) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + +def load_openhermes_coding(): + data = load_dataset("alvarobartt/openhermes-preferences-coding", split="train") + + samples = [] + for sample in data: + responses = [f'{response["role"]}: {response["content"]}' for response in sample["chosen"]] + samples.append("\n".join(responses)) + + return samples + +# Quantize +model.quantize( + tokenizer, + quant_config=quant_config, + calib_data=load_openhermes_coding(), + # MODIFY these parameters if need be: + # n_parallel_calib_samples=32, + # max_calib_samples=128, + # max_calib_seq_len=4096 +) + +# Save quantized model +model.save_quantized(quant_path) +tokenizer.save_pretrained(quant_path) + +print(f'Model is quantized and saved at "{quant_path}"') +``` + ### GGUF Export This computes AWQ scales and appliesthem to the model without running real quantization.