From 8ecaf4643677865330e48b95c7c5d015ba97c90b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:56:25 -0400 Subject: [PATCH 1/7] refactor Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 237 +++++++++--------- .../quantization/utils/helpers.py | 41 +++ 2 files changed, 163 insertions(+), 115 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 9f852c74f..2e250ce92 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,25 +14,27 @@ import logging -import math -import warnings -from typing import Optional +from enum import Enum +from typing import Optional, Tuple import torch -from compressed_tensors.quantization import ( +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) +from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, - KVCacheScaleType, + DynamicType, QuantizationArgs, - QuantizationMetadata, - QuantizationScheme, - QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, +from compressed_tensors.quantization.quant_config import QuantizationStatus +from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils import ( + is_fp4, + is_kv_cache_quant_scheme, + strategy_cdiv, ) -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -44,23 +46,28 @@ __all__ = [ "initialize_module_for_quantization", "is_attention_module", + "KVCacheScaleType", ] _LOGGER = logging.getLogger(__name__) +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - Attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme. + attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme - Previously initialized scales and zero points will be removed from - module if they no longer apply to the scheme + apply to full model with `model.apply(initialize_module_for_quantization)` :param module: module to set for calibration :param scheme: scheme to use for quantization. 
if None is provided, @@ -69,51 +76,61 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return - QuantizationMetadata.clear_all_qparams(module) - if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) else: + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" + ) + return + if scheme.input_activations is not None: _initialize_scale_zero_point( module, "input", scheme.input_activations, + observed_shape=(1, weight.shape[-1]), + observed_dtype=weight.dtype, force_zero_point=force_zero_point, ) if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) - - if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations - ) + _initialize_scale_zero_point( + module, + "weight", + scheme.weights, + observed_shape=weight.shape, + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) + + output_is_kv_cache = is_kv_cache_quant_scheme(scheme) + if scheme.output_activations is not None and not output_is_kv_cache: + _initialize_scale_zero_point( + module, + "output", + scheme.output_activations, + observed_shape=weight.shape[:-1], + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -136,18 +153,21 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, + observed_shape: Tuple[int], + observed_dtype: torch.dtype, force_zero_point: bool = True, ): - if quantization_args.dynamic is True: - return + strategy = quantization_args.strategy + dynamic = quantization_args.dynamic + actorder = quantization_args.actorder + device = get_execution_device(module) # avoid performing intialization ops on cpu - # initialize on execution device to avoid performing quantized ops on cpu - device = get_execution_device(module) + # Skip all intialization for fully dynamic quantization + if dynamic is True: + return - # 1. Create global_scales for tensor_group - generates - # a per tensor scale - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + # 0. 
Create global scale for tensor-group quantization + if strategy == QuantizationStrategy.TENSOR_GROUP: init_global_scale = Parameter( torch.empty(1, dtype=torch.float32, device=device), requires_grad=False, @@ -156,56 +176,55 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # 2. Infer expected scale/zero point shape - if quantization_args.strategy == QuantizationStrategy.TOKEN: + # Skip scale/zp initialization for locally dynamic quantization + if dynamic == DynamicType.LOCAL: + return + + # 1. Infer expected scale/zp shape + if strategy == QuantizationStrategy.TENSOR: + expected_shape = (1,) + + elif strategy == QuantizationStrategy.TOKEN: expected_shape = (1, 1) + + elif strategy == QuantizationStrategy.CHANNEL: + if len(observed_shape) < 1: + raise ValueError("Channel quant requires at least 1 observed dimension") + + expected_shape = (observed_shape[-2], 1) + + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") + + group_size = quantization_args.group_size + num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy) + expected_shape = (*observed_shape[:-1], num_groups) + + # initialize activation ordering if applicable + if actorder == ActivationOrdering.GROUP: + init_g_idx = Parameter( + torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + elif strategy == QuantizationStrategy.BLOCK: + assert quantization_args.block_structure is not None + if len(observed_shape) < 2: + raise ValueError("Block quant requires at least 2 observed dimensions") + + block_structure = quantization_args.block_structure + num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy) + num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) + expected_shape = (num_rows, num_cols) + else: - expected_shape = 1 - - if base_name == "weight" and weight_shape is not None: - if quantization_args.strategy == QuantizationStrategy.CHANNEL: - # (output_channels, 1) - only for weights - expected_shape = (weight_shape[0], 1) - elif quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - # GROUP/TENSOR_GROUP for both weights and activations - num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) - expected_shape = (weight_shape[0], max(num_groups, 1)) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only - # for weights - if quantization_args.block_structure is None: - raise ValueError( - "Block quantization requires block_structure to be specified" - ) - block_height, block_width = quantization_args.block_structure - rows, cols = weight_shape[-2], weight_shape[-1] - num_rows_blocks = math.ceil(rows / block_height) - num_cols_blocks = math.ceil(cols / block_width) - - # Warn if dimensions don't divide evenly - if rows % block_height != 0 or cols % block_width != 0: - warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide" - f"evenly by block structure {quantization_args.block_structure}. 
" - f"Some blocks will be incomplete which may affect quantization" - "quality.", - UserWarning, - ) - - expected_shape = (num_rows_blocks, num_cols_blocks) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - warnings.warn( - f"BLOCK quantization not supported for {base_name} activations. " - f"Falling back to tensor-level quantization.", - UserWarning, - ) - expected_shape = 1 + assert False, f"Unknown strategy {strategy}" - # 3. Identify quantization scale and zp dtype - scale_dtype = module.weight.dtype + # 2. Identify quantization scale and zp dtype + scale_dtype = observed_dtype if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -221,14 +240,12 @@ def _initialize_scale_zero_point( scale_dtype = torch.bfloat16 zp_dtype = quantization_args.pytorch_dtype() - # 4. Initializes empty scale, zero point, and g_idx parameters for the module - # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_scale", init_scale) + # 3. Initializes scale/zp for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( @@ -237,16 +254,6 @@ def _initialize_scale_zero_point( ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) - # only grouped activation ordering has g_idx - if quantization_args.actorder == ActivationOrdering.GROUP: - g_idx_shape = (weight_shape[1],) - g_idx_dtype = torch.int - init_g_idx = Parameter( - torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - def _initialize_attn_scales(module: Module) -> None: """Initlaize k_scale, v_scale for self_attn""" diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 1b6937d47..4821f51c4 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -27,11 +27,13 @@ ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.utils import deprecated +from loguru import logger from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module __all__ = [ + "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -47,6 +49,7 @@ "calculate_qparams", "generate_gparam", "is_fp4", + "strategy_cdiv", ] # target the self_attn layer @@ -233,6 +236,21 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max +def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa + """ + Checks the quantization status of a model. Assumes all modules in the model have + the same status, so only the first quantized model is checked. 
+ + :param model: model to check quantization status for + :return: quantization status if the model is quantized, otherwise None + """ + for module in model.modules(): + status = getattr(module, "quantization_status", None) + if status is not None: + return status + return None + + def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization @@ -461,3 +479,26 @@ def generate_gparam( max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) global_scale = scale_data.max * quant_data.max / max_val_pos return global_scale.to(dtype).reshape([1]) + + +def strategy_cdiv( + value: int, + divisor: int, + strategy: Optional[QuantizationStrategy], + strict: bool = False, +) -> int: + dividend = math.ceil(value / divisor) + if dividend * divisor != value: + message = ( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {value} and group/block size {divisor}. " + "consider reducing the group/block size or ignoring modules with " + f"weights not divisible by {divisor}" + ) + if strict: + raise ValueError(message) + + else: + logger.bind(log_once=True).warning(message) + + return dividend From 92705f2fbede07a9356f00e217db04b736a95709 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 09:59:55 -0400 Subject: [PATCH 2/7] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2e250ce92..c026c85ac 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,22 +14,23 @@ import logging -from enum import Enum from typing import Optional, Tuple import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from compressed_tensors.quantization.quant_args import ( +from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, + KVCacheScaleType, QuantizationArgs, + QuantizationMetadata, + QuantizationScheme, + QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, @@ -46,28 +47,23 @@ __all__ = [ "initialize_module_for_quantization", "is_attention_module", - "KVCacheScaleType", ] _LOGGER = logging.getLogger(__name__) -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme + Attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme. - apply to full model with `model.apply(initialize_module_for_quantization)` + Previously initialized scales and zero points will be removed from + module if they no longer apply to the scheme :param module: module to set for calibration :param scheme: scheme to use for quantization. 
if None is provided, @@ -80,6 +76,8 @@ def initialize_module_for_quantization( if scheme is None: return + QuantizationMetadata.clear_all_qparams(module) + if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) From 70299f37c8b168aa6e18c38e9d0eac9813daa7a6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 10:01:21 -0400 Subject: [PATCH 3/7] reduce diff Signed-off-by: Kyle Sayers --- .../quantization/utils/helpers.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 4821f51c4..26c09b41a 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -33,7 +33,6 @@ __all__ = [ - "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -236,21 +235,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa - """ - Checks the quantization status of a model. Assumes all modules in the model have - the same status, so only the first quantized model is checked. - - :param model: model to check quantization status for - :return: quantization status if the model is quantized, otherwise None - """ - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization From 42ee086449b57b6f0e2df9c4c80a47b9f29d42d0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 10:11:31 -0400 Subject: [PATCH 4/7] initialize_qparams Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index c026c85ac..5bf59afb6 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -47,6 +47,7 @@ __all__ = [ "initialize_module_for_quantization", "is_attention_module", + "initialize_qparams", ] @@ -100,7 +101,7 @@ def initialize_module_for_quantization( return if scheme.input_activations is not None: - _initialize_scale_zero_point( + initialize_qparams( module, "input", scheme.input_activations, @@ -110,7 +111,7 @@ def initialize_module_for_quantization( ) if scheme.weights is not None: - _initialize_scale_zero_point( + initialize_qparams( module, "weight", scheme.weights, @@ -121,7 +122,7 @@ def initialize_module_for_quantization( output_is_kv_cache = is_kv_cache_quant_scheme(scheme) if scheme.output_activations is not None and not output_is_kv_cache: - _initialize_scale_zero_point( + initialize_qparams( module, "output", scheme.output_activations, @@ -147,7 +148,7 @@ def is_attention_module(module: Module): ) -def _initialize_scale_zero_point( +def initialize_qparams( module: Module, base_name: str, quantization_args: QuantizationArgs, @@ -155,6 +156,21 @@ def _initialize_scale_zero_point( observed_dtype: torch.dtype, force_zero_point: bool = True, ): + """ + Initialize quantization parameters for a given basename according to the passed + quantization args. 
The shape and dtype of the observed weight/activation must also + be provided. + + Scales will always be initialized. Global scales are initialized depending on args. + Zero points will be initialized if not symmetric or if `force_zero_point` is True. + + :param module: module to register qparams to + :param base_name: base name of qparams, for example "input", "weight", "k", "v" + :param quantization_args: arguments for quantization + :param observed_shape: last (right-most) known dimensions of the observed weight/act + :param observed_dtype: dtype of the observed weight/actt + :param force_zero_point: force the zero_point parameter to be initialized + """ strategy = quantization_args.strategy dynamic = quantization_args.dynamic actorder = quantization_args.actorder From 72560d4dfbe4e1016a54a757e6c6a1dc13ba9788 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 30 Sep 2025 17:27:27 -0400 Subject: [PATCH 5/7] simplify activation shape Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5bf59afb6..8dbda2787 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -105,7 +105,7 @@ def initialize_module_for_quantization( module, "input", scheme.input_activations, - observed_shape=(1, weight.shape[-1]), + observed_shape=weight.shape[-1:], observed_dtype=weight.dtype, force_zero_point=force_zero_point, ) From 8053b5178e500b8eab60b0bef9e560b01a040306 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:20:01 -0400 Subject: [PATCH 6/7] increase num of required observed dims Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/initialize.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 8dbda2787..390b174a9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -202,8 +202,8 @@ def initialize_qparams( expected_shape = (1, 1) elif strategy == QuantizationStrategy.CHANNEL: - if len(observed_shape) < 1: - raise ValueError("Channel quant requires at least 1 observed dimension") + if len(observed_shape) < 2: + raise ValueError("Channel quant requires at least 2 observed dimensions") expected_shape = (observed_shape[-2], 1) @@ -234,6 +234,12 @@ def initialize_qparams( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) + elif strategy == QuantizationStrategy.ATTN_HEAD: + if len(observed_shape) < 2: + raise ValueError("Attention quant requires at least 2 observed dimensions") + + expected_shape = (observed_shape[-2], 1) + else: assert False, f"Unknown strategy {strategy}" From 1ef32e3c030e1ffea9cb7ed272b9813c1ba12f49 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 6 Oct 2025 18:30:24 -0400 Subject: [PATCH 7/7] remove attention head Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 390b174a9..4b896d37d 100644 --- 
a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -234,12 +234,6 @@ def initialize_qparams( num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) expected_shape = (num_rows, num_cols) - elif strategy == QuantizationStrategy.ATTN_HEAD: - if len(observed_shape) < 2: - raise ValueError("Attention quant requires at least 2 observed dimensions") - - expected_shape = (observed_shape[-2], 1) - else: assert False, f"Unknown strategy {strategy}"
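
Note for reviewers: below is a minimal standalone sketch, not part of the patch, illustrating how the refactored `initialize_qparams()` maps an observed Linear weight shape to the expected scale/zero-point shape for each quantization strategy. The helper names `ceil_div` and `expected_scale_shape` are illustrative simplifications of `strategy_cdiv` and the strategy branches added above, and the sketch omits the divisibility warning/strict check that `strategy_cdiv` performs.

```python
# Standalone sketch of the expected-scale-shape logic introduced in this series.
# Not the library implementation; see initialize_qparams() in
# compressed_tensors/quantization/lifecycle/initialize.py for the real code.
import math
from typing import Optional, Tuple


def ceil_div(value: int, divisor: int) -> int:
    # strategy_cdiv() additionally warns (or raises when strict=True) when the
    # division is not exact; this sketch reproduces only the arithmetic.
    return math.ceil(value / divisor)


def expected_scale_shape(
    strategy: str,
    observed_shape: Tuple[int, ...],
    group_size: Optional[int] = None,
    block_structure: Optional[Tuple[int, int]] = None,
) -> Tuple[int, ...]:
    if strategy == "tensor":
        return (1,)
    if strategy == "token":
        return (1, 1)
    if strategy == "channel":
        # one scale per output channel
        return (observed_shape[-2], 1)
    if strategy in ("group", "tensor_group"):
        # one scale per group along the input (last) dimension
        return (*observed_shape[:-1], ceil_div(observed_shape[-1], group_size))
    if strategy == "block":
        # one scale per (rows x cols) block of the weight
        return (
            ceil_div(observed_shape[-2], block_structure[-2]),
            ceil_div(observed_shape[-1], block_structure[-1]),
        )
    raise ValueError(f"Unknown strategy {strategy}")


if __name__ == "__main__":
    weight_shape = (512, 1024)  # (out_features, in_features) of a Linear layer
    print(expected_scale_shape("channel", weight_shape))                            # (512, 1)
    print(expected_scale_shape("group", weight_shape, group_size=128))              # (512, 8)
    print(expected_scale_shape("block", weight_shape, block_structure=(128, 128)))  # (4, 8)
```

For activation quantization the patch passes `weight.shape[-1:]` (inputs) or `weight.shape[:-1]` (outputs) as the observed shape, so the same branches apply without a separate activation code path.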