src/compressed_tensors/quantization/lifecycle/initialize.py (125 additions, 104 deletions)
@@ -14,14 +14,13 @@


import logging
import math
import warnings
from typing import Optional
from typing import Optional, Tuple

import torch
from compressed_tensors.quantization import (
FP8_E4M3_DATA,
ActivationOrdering,
DynamicType,
KVCacheScaleType,
QuantizationArgs,
QuantizationMetadata,
@@ -32,7 +31,11 @@
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
from compressed_tensors.quantization.utils import (
is_fp4,
is_kv_cache_quant_scheme,
strategy_cdiv,
)
from compressed_tensors.utils import (
disable_hf_hook,
get_execution_device,
@@ -44,6 +47,7 @@
__all__ = [
"initialize_module_for_quantization",
"is_attention_module",
"initialize_qparams",
]


@@ -69,10 +73,8 @@ def initialize_module_for_quantization(
:param force_zero_point: whether to force initialization of a zero point for
symmetric quantization
"""
# TODO: don't initialize parameters when running decompression
scheme = scheme or getattr(module, "quantization_scheme", None)
if scheme is None:
# no scheme passed and layer not targeted for quantization - skip
return

QuantizationMetadata.clear_all_qparams(module)
@@ -82,38 +84,52 @@
_initialize_attn_scales(module)

else:
if not isinstance(module, torch.nn.Linear):
_LOGGER.warning(f"Attempting to quantize module of type {type(module)}")

# use weight to determine observed shapes and dtype
if hasattr(module, "weight"):
weight = module.weight
assert isinstance(weight, torch.Tensor)
else:
# Note that a weight is required for both weight and activation
# quantization in order to know the dtype of activation scales
_LOGGER.warning(
f"module type {type(module)} targeted for quantization but "
f"has no attribute weight, skipping quantization for {type(module)}"
)
return

if scheme.input_activations is not None:
_initialize_scale_zero_point(
initialize_qparams(
module,
"input",
scheme.input_activations,
observed_shape=weight.shape[-1:],
observed_dtype=weight.dtype,
force_zero_point=force_zero_point,
)

if scheme.weights is not None:
if hasattr(module, "weight"):
weight_shape = None
if isinstance(module, torch.nn.Linear):
weight_shape = module.weight.shape
_initialize_scale_zero_point(
module,
"weight",
scheme.weights,
weight_shape=weight_shape,
force_zero_point=force_zero_point,
)
else:
_LOGGER.warning(
f"module type {type(module)} targeted for weight quantization but "
"has no attribute weight, skipping weight quantization "
f"for {type(module)}"
)

if scheme.output_activations is not None:
if not is_kv_cache_quant_scheme(scheme):
_initialize_scale_zero_point(
module, "output", scheme.output_activations
)
initialize_qparams(
module,
"weight",
scheme.weights,
observed_shape=weight.shape,
observed_dtype=weight.dtype,
force_zero_point=force_zero_point,
)

output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
if scheme.output_activations is not None and not output_is_kv_cache:
initialize_qparams(
module,
"output",
scheme.output_activations,
observed_shape=weight.shape[:-1],
observed_dtype=weight.dtype,
force_zero_point=force_zero_point,
)

module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED
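
For orientation, a minimal sketch of the shapes the refactored flow observes for a plain torch.nn.Linear (torch stores Linear weights as (out_features, in_features); the module and sizes below are illustrative):

import torch

# Linear(in_features=1024, out_features=4096) -> weight.shape == (4096, 1024)
module = torch.nn.Linear(1024, 4096)
weight = module.weight

assert weight.shape[-1:] == (1024,)   # observed by "input" qparams
assert weight.shape == (4096, 1024)   # observed by "weight" qparams
assert weight.shape[:-1] == (4096,)   # observed by "output" qparams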
@@ -132,22 +148,40 @@ def is_attention_module(module: Module):
)


def _initialize_scale_zero_point(
def initialize_qparams(
module: Module,
base_name: str,
quantization_args: QuantizationArgs,
weight_shape: Optional[torch.Size] = None,
observed_shape: Tuple[int, ...],
observed_dtype: torch.dtype,
force_zero_point: bool = True,
):
if quantization_args.dynamic is True:
return
"""
Initialize quantization parameters for a given base name according to the passed
quantization args. The shape and dtype of the observed weight/activation must also
be provided.

Scales will always be initialized. Global scales are initialized depending on args.
Zero points will be initialized if not symmetric or if `force_zero_point` is True.

:param module: module to register qparams to
:param base_name: base name of qparams, for example "input", "weight", "k", "v"
:param quantization_args: arguments for quantization
:param observed_shape: last (right-most) known dimensions of the observed weight/act
:param observed_dtype: dtype of the observed weight/act
:param force_zero_point: force the zero_point parameter to be initialized
"""
strategy = quantization_args.strategy
dynamic = quantization_args.dynamic
actorder = quantization_args.actorder
device = get_execution_device(module) # avoid performing intialization ops on cpu

# initialize on execution device to avoid performing quantized ops on cpu
device = get_execution_device(module)
# Skip all initialization for fully dynamic quantization
if dynamic is True:
return

# 1. Create global_scales for tensor_group - generates
# a per tensor scale
if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
# 0. Create global scale for tensor-group quantization
if strategy == QuantizationStrategy.TENSOR_GROUP:
init_global_scale = Parameter(
torch.empty(1, dtype=torch.float32, device=device),
requires_grad=False,
@@ -156,56 +190,55 @@ def _initialize_scale_zero_point(
module, f"{base_name}_global_scale", init_global_scale
)

# 2. Infer expected scale/zero point shape
if quantization_args.strategy == QuantizationStrategy.TOKEN:
# Skip scale/zp initialization for locally dynamic quantization
if dynamic == DynamicType.LOCAL:
return
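# (e.g. NVFP4-style activation schemes are believed to use dynamic="local":
# the fused global scale above is registered, but per-group scales are
# computed at runtime)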

# 1. Infer expected scale/zp shape
if strategy == QuantizationStrategy.TENSOR:
expected_shape = (1,)

elif strategy == QuantizationStrategy.TOKEN:
expected_shape = (1, 1)

elif strategy == QuantizationStrategy.CHANNEL:
if len(observed_shape) < 2:
raise ValueError("Channel quant requires at least 2 observed dimensions")

expected_shape = (observed_shape[-2], 1)

elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
assert quantization_args.group_size is not None
if len(observed_shape) < 1:
raise ValueError("Group quant requires at least 1 observed dimension")

group_size = quantization_args.group_size
num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
expected_shape = (*observed_shape[:-1], num_groups)

# initialize activation ordering if applicable
if actorder == ActivationOrdering.GROUP:
init_g_idx = Parameter(
torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
requires_grad=False,
)
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)

elif strategy == QuantizationStrategy.BLOCK:
assert quantization_args.block_structure is not None
if len(observed_shape) < 2:
raise ValueError("Block quant requires at least 2 observed dimensions")

block_structure = quantization_args.block_structure
num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
expected_shape = (num_rows, num_cols)

else:
expected_shape = 1

if base_name == "weight" and weight_shape is not None:
if quantization_args.strategy == QuantizationStrategy.CHANNEL:
# (output_channels, 1) - only for weights
expected_shape = (weight_shape[0], 1)
elif quantization_args.strategy in (
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
):
# GROUP/TENSOR_GROUP for both weights and activations
num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
expected_shape = (weight_shape[0], max(num_groups, 1))
elif quantization_args.strategy == QuantizationStrategy.BLOCK:
# For block quantization, scale shape should match number of blocks - only
# for weights
if quantization_args.block_structure is None:
raise ValueError(
"Block quantization requires block_structure to be specified"
)
block_height, block_width = quantization_args.block_structure
rows, cols = weight_shape[-2], weight_shape[-1]
num_rows_blocks = math.ceil(rows / block_height)
num_cols_blocks = math.ceil(cols / block_width)

# Warn if dimensions don't divide evenly
if rows % block_height != 0 or cols % block_width != 0:
warnings.warn(
f"Block quantization: tensor shape {weight_shape} does not divide"
f"evenly by block structure {quantization_args.block_structure}. "
f"Some blocks will be incomplete which may affect quantization"
"quality.",
UserWarning,
)

expected_shape = (num_rows_blocks, num_cols_blocks)
elif quantization_args.strategy == QuantizationStrategy.BLOCK:
warnings.warn(
f"BLOCK quantization not supported for {base_name} activations. "
f"Falling back to tensor-level quantization.",
UserWarning,
)
expected_shape = 1
assert False, f"Unknown strategy {strategy}"

# 3. Identify quantization scale and zp dtype
scale_dtype = module.weight.dtype
# 2. Identify quantization scale and zp dtype
scale_dtype = observed_dtype

if is_fp4(quantization_args=quantization_args):
scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -221,14 +254,12 @@
scale_dtype = torch.bfloat16
zp_dtype = quantization_args.pytorch_dtype()

# 4. Initializes empty scale, zero point, and g_idx parameters for the module
# do not init scales for quantzation_args.dynamic == DynamicType.local
if not quantization_args.dynamic:
init_scale = Parameter(
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
)
register_offload_parameter(module, f"{base_name}_scale", init_scale)
# 3. Initializes scale/zp for the module
init_scale = Parameter(
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
)
register_offload_parameter(module, f"{base_name}_scale", init_scale)

if force_zero_point or not quantization_args.symmetric:
init_zero_point = Parameter(
Expand All @@ -237,16 +268,6 @@ def _initialize_scale_zero_point(
)
register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

# only grouped activation ordering has g_idx
if quantization_args.actorder == ActivationOrdering.GROUP:
g_idx_shape = (weight_shape[1],)
g_idx_dtype = torch.int
init_g_idx = Parameter(
torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
requires_grad=False,
)
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
"""Initlaize k_scale, v_scale for self_attn"""
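Putting the shape-inference branches together, an illustrative sketch of the expected scale shapes for a (4096, 1024) observed weight (group_size=128 and block_structure=[128, 128] are assumed example values, not defaults):

import math

observed_shape = (4096, 1024)       # e.g. a Linear weight
group_size = 128                    # assumed example value
block_rows, block_cols = 128, 128   # assumed example block_structure

# TENSOR -> (1,); TOKEN -> (1, 1); CHANNEL -> one scale per output channel
assert (observed_shape[-2], 1) == (4096, 1)

# GROUP / TENSOR_GROUP -> groups along the last (input) dimension
assert (*observed_shape[:-1], math.ceil(1024 / group_size)) == (4096, 8)

# BLOCK -> one scale per (block_rows x block_cols) tile
assert (math.ceil(4096 / block_rows), math.ceil(1024 / block_cols)) == (32, 8)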
src/compressed_tensors/quantization/utils/helpers.py (25 additions, 0 deletions)
@@ -27,6 +27,7 @@
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.utils import deprecated
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module

@@ -47,6 +48,7 @@
"calculate_qparams",
"generate_gparam",
"is_fp4",
"strategy_cdiv",
]

# target the self_attn layer
@@ -461,3 +463,26 @@ def generate_gparam(
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
global_scale = scale_data.max * quant_data.max / max_val_pos
return global_scale.to(dtype).reshape([1])


def strategy_cdiv(
value: int,
divisor: int,
strategy: Optional[QuantizationStrategy],
strict: bool = False,
) -> int:
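"""
Ceiling-divide a weight/activation dimension by a group or block size,
returning the number of groups/blocks needed to cover it. If the division
is not exact, raises (strict=True) or logs a one-time warning (default).
"""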
quotient = math.ceil(value / divisor)
if quotient * divisor != value:
message = (
f"{strategy} quantization strategy requires strict division of "
f"weight/activation size {value} by group/block size {divisor}. "
"Consider reducing the group/block size or ignoring modules with "
f"weights not divisible by {divisor}"
)
if strict:
raise ValueError(message)

else:
logger.bind(log_once=True).warning(message)

return quotient
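
A usage sketch of the new helper (import paths are assumed from the package layout in this diff; values are illustrative):

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.utils import strategy_cdiv

# evenly divisible: 1024 columns at group size 128 -> 8 groups, no warning
assert strategy_cdiv(1024, 128, QuantizationStrategy.GROUP) == 8

# not divisible: rounds up to 2 and logs a one-time warning by default
assert strategy_cdiv(100, 64, QuantizationStrategy.GROUP) == 2

# ... or raises when strict=True
try:
    strategy_cdiv(100, 64, QuantizationStrategy.GROUP, strict=True)
except ValueError:
    pass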