2 changes: 1 addition & 1 deletion setup.py
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
)

def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict", "loguru"]

def _setup_extras() -> Dict:
return {
1 change: 1 addition & 0 deletions src/compressed_tensors/config/__init__.py
@@ -15,5 +15,6 @@
# flake8: noqa
from .base import *
from .dense import *
from .format import *
from .sparse_24_bitmask import *
from .sparse_bitmask import *
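Note: format.py (added below) defines __all__ = ["infer_and_set_per_module_quantization_format"], so this wildcard import should re-export the new helper directly from the config package, i.e.:

from compressed_tensors.config import infer_and_set_per_module_quantization_format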
151 changes: 151 additions & 0 deletions src/compressed_tensors/config/format.py
@@ -0,0 +1,151 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger


__all__ = ["infer_and_set_per_module_quantization_format"]


def _get_quant_compression_format(
input_args: QuantizationArgs,
weight_args: QuantizationArgs,
sparsity_structure: Optional[str] = None,
) -> CompressionFormat:
"""
Using the weight and input quantization args as well as an optional
sparsity structure, determine the compression format that should be
applied to a given module

:param input_args: input quantization parameters
:param weight_args: weight quantization parameters
:param sparsity_structure: optional (global) model sparsity
structure
:return: CompressionFormat for the module
"""
is_24_structure = (
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
)
is_weight_only = weight_args is not None and input_args is None

# fp4 weight quantization (covers w4a16 and w4a4 float schemes)
if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
if weight_args.strategy in (
QuantizationStrategy.TENSOR_GROUP.value,
QuantizationStrategy.CHANNEL.value,
QuantizationStrategy.GROUP.value,
):
return CompressionFormat.nvfp4_pack_quantized
else:
if is_weight_only:
return CompressionFormat.naive_quantized
return CompressionFormat.float_quantized

if is_weight_only: # w4a16 and w8a16, int
is_valid_pack = (
weight_args.num_bits in [4, 8]
and weight_args.type == QuantizationType.INT.value
)
if not is_valid_pack:  # packing only valid for int4 and int8
return CompressionFormat.naive_quantized
if is_24_structure:
if (
weight_args.strategy != QuantizationStrategy.CHANNEL.value
and weight_args.strategy != QuantizationStrategy.GROUP.value
):
# marlin24 kernel only applicable for channel/group quantization
return CompressionFormat.pack_quantized
return CompressionFormat.marlin_24
return CompressionFormat.pack_quantized

else: # w8a8 float and int
if (
weight_args.type == QuantizationType.FLOAT.value
and weight_args.num_bits == 8
):
return CompressionFormat.float_quantized
if weight_args.type == QuantizationType.INT.value:
return CompressionFormat.int_quantized

return CompressionFormat.naive_quantized


def set_per_module_format(
module: torch.nn.Module, sparsity_structure: Optional[str] = None
):
"""
Determine and set the per module quantization format given quantization args
and sparsity structure.

:param module: module whose quantization format is inferred and set
:param sparsity_structure: optional sparsity structure applied to the module

"""
weight_scheme = module.quantization_scheme.weights
input_scheme = module.quantization_scheme.input_activations
if weight_scheme is None:
return # no weight quant - nothing to compress
compression_format = _get_quant_compression_format(
input_scheme, weight_scheme, sparsity_structure
)

# If set, we check if it matches our inferred one
if module.quantization_scheme.format is not None:
# If it does not, warn the user; the provided format is kept
if module.quantization_scheme.format != compression_format.value:
logger.warning(
"The provided format for the module does not match the "
"inferred format. Compression may fail "
)
else:
# If not set, we set ours
module.quantization_scheme.format = compression_format.value


def infer_and_set_per_module_quantization_format(
model: torch.nn.Module,
sparsity_structure: Optional[str] = None,
) -> Union[str, List[str]]:
"""
Infers the quantization format for each quantized module in the model based on
its state and the provided compression arguments, and updates the
quantization_scheme.format value accordingly. Returns the list of unique
formats found in the model, or the dense format if no quantized modules are
found.

For a summary of the formats, see `docs/guides/compression_formats.md`.

:param model: model to check for quantization
:param sparsity_structure: optional sparsity structure applied to the model
:return: list of unique compression formats in the model, or the dense format
if no modules are quantized
"""
unique_formats = []
for submodule in model.modules():
if is_module_quantized(submodule):
set_per_module_format(submodule, sparsity_structure)
if submodule.quantization_scheme.format not in unique_formats:
unique_formats.append(submodule.quantization_scheme.format)

if len(unique_formats) > 0:
return unique_formats
return CompressionFormat.dense.value
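Reviewer note (not part of the PR): a minimal usage sketch of the helper above, reusing the W4A16 preset and format strings that appear in the test below, and assuming the settable quantization_scheme.format field this PR relies on. It exercises the mismatch branch: when a format is already set and disagrees with the inferred one, set_per_module_format only warns and keeps the provided value.

import torch
from compressed_tensors.config.format import (
    infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme

# Attach a W4A16 scheme to a single Linear layer
layer = torch.nn.Linear(16, 16)
layer.quantization_scheme = preset_name_to_scheme("W4A16", targets=["Linear"])

# Pre-set a format that conflicts with the inferred "pack-quantized"
layer.quantization_scheme.format = "float-quantized"

formats = infer_and_set_per_module_quantization_format(
    layer, sparsity_structure="unstructured"
)
# a loguru warning is emitted and the provided format is kept:
# formats == ["float-quantized"]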
65 changes: 65 additions & 0 deletions tests/test_configs/test_infer_quant.py
@@ -0,0 +1,65 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import pytest
import torch
from compressed_tensors.config.format import (
infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme


@pytest.mark.parametrize(
"preset,sparsity_structure,expected_format",
[
["W8A8", "unstructured", "int-quantized"],
["W8A16", "unstructured", "pack-quantized"],
["W8A16", "2:4", "marlin-24"],
["W4A16", "unstructured", "pack-quantized"],
["W4A16", "2:4", "marlin-24"],
["FP8", "unstructured", "float-quantized"],
],
)
def test_infer_quant_format(preset, sparsity_structure, expected_format):
quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])

dummy_model = torch.nn.Sequential(
OrderedDict(
[
("fc1", torch.nn.Linear(8, 16, bias=True)),
("fc2", torch.nn.Linear(16, 32, bias=True)),
(
"block1",
torch.nn.Sequential(
OrderedDict(
[
("fc1", torch.nn.Linear(32, 16, bias=True)),
("fc2", torch.nn.Linear(16, 8, bias=True)),
]
)
),
),
]
)
)

for _, module in dummy_model.named_modules():
module.quantization_scheme = quant_scheme

inferred_format = infer_and_set_per_module_quantization_format(
dummy_model, sparsity_structure=sparsity_structure
)
assert inferred_format[0] == expected_format
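As a quick sanity check, the new test should run standalone (assuming pytest and an editable install of the package), e.g. pytest tests/test_configs/test_infer_quant.py -v; the parametrized cases also double as a reference for the expected preset-to-format mapping.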