diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 9f852c74f..4b896d37d 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import logging
-import math
-import warnings
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     KVCacheScaleType,
     QuantizationArgs,
     QuantizationMetadata,
@@ -32,7 +31,11 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strategy_cdiv,
+)
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -44,6 +47,7 @@
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
+    "initialize_qparams",
 ]
 
 
@@ -69,10 +73,8 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point
         for symmetric quantization
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return
 
     QuantizationMetadata.clear_all_qparams(module)
@@ -82,38 +84,52 @@ def initialize_module_for_quantization(
         _initialize_attn_scales(module)
 
     else:
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
+            )
+            return
+
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
+            initialize_qparams(
                 module,
                 "input",
                 scheme.input_activations,
+                observed_shape=weight.shape[-1:],
+                observed_dtype=weight.dtype,
                 force_zero_point=force_zero_point,
             )
 
         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
-
-        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            initialize_qparams(
+                module,
+                "weight",
+                scheme.weights,
+                observed_shape=weight.shape,
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
+
+        output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
+        if scheme.output_activations is not None and not output_is_kv_cache:
+            initialize_qparams(
+                module,
+                "output",
+                scheme.output_activations,
+                observed_shape=weight.shape[:-1],
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
@@ -132,22 +148,40 @@ def is_attention_module(module: Module):
     )
 
 
-def _initialize_scale_zero_point(
+def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: Tuple[int, ...],
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
-    if quantization_args.dynamic is True:
-        return
+    """
+    Initialize quantization parameters for a given base name according to the passed
+    quantization args. The shape and dtype of the observed weight/activation must also
+    be provided.
+
+    Scales will always be initialized. Global scales are initialized depending on args.
+    Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+    :param module: module to register qparams to
+    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+    :param quantization_args: arguments for quantization
+    :param observed_shape: last (right-most) known dimensions of the observed weight/act
+    :param observed_dtype: dtype of the observed weight/act
+    :param force_zero_point: force the zero_point parameter to be initialized
+    """
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing initialization ops on cpu
 
-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all initialization for fully dynamic quantization
+    if dynamic is True:
+        return
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -156,56 +190,55 @@
             module, f"{base_name}_global_scale", init_global_scale
         )
 
-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
+            raise ValueError("Channel quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
     else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+        assert False, f"Unknown strategy {strategy}"
 
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -221,14 +254,12 @@
             scale_dtype = torch.bfloat16
         zp_dtype = quantization_args.pytorch_dtype()
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -237,16 +268,6 @@
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-
 
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 1b6937d47..26c09b41a 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -27,6 +27,7 @@
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.utils import deprecated
+from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
@@ -47,6 +48,7 @@
     "calculate_qparams",
     "generate_gparam",
     "is_fp4",
+    "strategy_cdiv",
 ]
 
 # target the self_attn layer
@@ -461,3 +463,26 @@ def generate_gparam(
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
     return global_scale.to(dtype).reshape([1])
+
+
+def strategy_cdiv(
+    value: int,
+    divisor: int,
+    strategy: Optional[QuantizationStrategy],
+    strict: bool = False,
+) -> int:
+    quotient = math.ceil(value / divisor)
+    if quotient * divisor != value:
+        message = (
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {value} and group/block size {divisor}. "
+            "Consider reducing the group/block size or ignoring modules with "
+            f"weights not divisible by {divisor}."
+        )
+        if strict:
+            raise ValueError(message)
+
+        else:
+            logger.bind(log_once=True).warning(message)
+
+    return quotient
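
Usage sketch (not part of the patch): the snippet below illustrates the qparam shapes produced by the new `initialize_qparams` helper and the rounding behavior of `strategy_cdiv`, assuming this diff is applied. The `QuantizationArgs` fields and import paths mirror the existing compressed-tensors API; the example modules and sizes are illustrative only.

```python
import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.initialize import initialize_qparams
from compressed_tensors.quantization.utils import strategy_cdiv

# Per-channel weight quantization: one scale per output channel -> (out_features, 1)
channel_linear = torch.nn.Linear(512, 256, bias=False)
initialize_qparams(
    channel_linear,
    "weight",
    QuantizationArgs(num_bits=8, symmetric=True, strategy="channel"),
    observed_shape=channel_linear.weight.shape,  # (256, 512)
    observed_dtype=channel_linear.weight.dtype,
)
print(channel_linear.weight_scale.shape)  # torch.Size([256, 1])

# Group weight quantization: 512 input columns / group_size 128 -> 4 groups per row
group_linear = torch.nn.Linear(512, 256, bias=False)
initialize_qparams(
    group_linear,
    "weight",
    QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128),
    observed_shape=group_linear.weight.shape,
    observed_dtype=group_linear.weight.dtype,
)
print(group_linear.weight_scale.shape)  # torch.Size([256, 4])

# strategy_cdiv rounds up and flags sizes that do not divide evenly
strategy_cdiv(512, 128, QuantizationStrategy.GROUP)  # 4, exact division
strategy_cdiv(500, 128, QuantizationStrategy.GROUP)  # 4, logs a one-time warning
# strategy_cdiv(500, 128, QuantizationStrategy.GROUP, strict=True) would raise ValueError
```

The shapes follow step 1 of `initialize_qparams`: `(out_features, 1)` for CHANNEL, and `(*observed_shape[:-1], ceil(in_features / group_size))` for GROUP/TENSOR_GROUP.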