-
Notifications
You must be signed in to change notification settings - Fork 281
add mxfp8 qat #2299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lkk12014402
wants to merge
16
commits into
intel:master
Choose a base branch
from
lkk12014402:qat_mxfp8
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add mxfp8 qat #2299
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b34fb32
add mxfp8 qat code, mxfp8fwd-bf16bwd.
lkk12014402 7f99561
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b6d74ae
fix comments.
lkk12014402 c9a0026
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1651d71
fix code style.
lkk12014402 fcf4b86
add unit tests.
lkk12014402 089c247
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6c0621d
update `prepare_qat` entry.
lkk12014402 a1f8c3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fbe0918
update `prepare_qat` code style to align with torchao.
lkk12014402 4d7508f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6d89e55
add qat test ut.
lkk12014402 0551717
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ece99c3
fix ut.
lkk12014402 1addd32
update qat ut assert.
lkk12014402 221a496
Merge branch 'master' into qat_mxfp8
lkk12014402 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) 2025 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# pylint:disable=import-error | ||
"""QAT (Quantization Aware Tuning).""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
# Copyright (c) 2025 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Quantized Linear.""" | ||
|
||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .tensor_quantizer import TensorQuantizer | ||
|
||
|
||
class QuantLinear(nn.Module): | ||
"""Quantized version of nn.Linear.""" | ||
|
||
def forward(self, input: torch.Tensor): | ||
"""Add weight/input/output of quantization for the original forward method.""" | ||
qw = self.weight_quantizer(self.weight) | ||
qi = self.input_quantizer(input) | ||
out = F.linear(qi, qw, self.bias) | ||
out = self.output_quantizer(out) | ||
return out | ||
|
||
def _setup(self, quant_cfg): | ||
"""Init quantizer.""" | ||
self.weight_quantizer = TensorQuantizer( | ||
data_type=quant_cfg.data_type, | ||
block_size=quant_cfg.group_size, | ||
bits=quant_cfg.bits, | ||
sym=quant_cfg.sym, | ||
if_quant=True, | ||
learn_exponent=False, | ||
) | ||
self.input_quantizer = TensorQuantizer( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM, just one question, how do we set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. disable the input quantizer call |
||
data_type=quant_cfg.act_data_type, | ||
block_size=quant_cfg.act_group_size, | ||
bits=quant_cfg.act_bits, | ||
sym=quant_cfg.act_sym, | ||
if_quant=True, | ||
learn_exponent=False, | ||
) | ||
self.output_quantizer = TensorQuantizer( | ||
data_type=quant_cfg.act_data_type, | ||
block_size=quant_cfg.act_group_size, | ||
bits=quant_cfg.act_bits, | ||
sym=quant_cfg.act_sym, | ||
if_quant=False, | ||
) | ||
# Currently don't quant output | ||
self.output_quantizer.disable() | ||
|
||
# TODO: remove | ||
self.original_weight_dtype = None if self.weight is None else self.weight.dtype | ||
|
||
def extra_repr(self) -> str: | ||
"""Generate extra_repr making sure import keys exist in self.__dict__.""" | ||
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" | ||
|
||
def __repr__(self): | ||
"""Overriding the __repr__ method, makes the output more concise and meaningful.""" | ||
return ( | ||
f"QuantLinear(\n" | ||
f" in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}\n" | ||
f" (input_quantizer): {self.input_quantizer}\n" | ||
f" (output_quantizer): {self.output_quantizer}\n" | ||
f" (weight_quantizer): {self.weight_quantizer}\n" | ||
f")" | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
# | ||
# Copyright (c) 2025 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Utils for quantization.""" | ||
|
||
import types | ||
from typing import Any | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from .quant_linear import QuantLinear | ||
from .tensor_quantizer import TensorQuantizer | ||
|
||
|
||
def convert(module: nn.Module, quant_cfg=None, quant_module=None): | ||
"""Convert the model to a quantized one with quant config.""" | ||
|
||
# update class | ||
original_cls = type(module) | ||
module.__class__ = quant_module | ||
module.forward = types.MethodType(quant_module.forward, module) | ||
|
||
# setup quantizers | ||
module._setup(quant_cfg) | ||
|
||
return module | ||
|
||
|
||
def replace_with_quant_linear(model, quant_cfg=None): | ||
yiliu30 marked this conversation as resolved.
Show resolved
Hide resolved
yiliu30 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Recursively replace the module with quantized module.""" | ||
|
||
# TODO: support more modules, like kv. | ||
for name, child in model.named_children(): | ||
if isinstance(child, nn.Linear): | ||
if "lm_head" in name: | ||
continue | ||
# REPLACE on the parent (model), not on child | ||
quantized = convert(child, quant_cfg, QuantLinear) | ||
setattr(model, name, quantized) | ||
|
||
# now recurse into whichever module is now at `model.name` | ||
replace_with_quant_linear(getattr(model, name), quant_cfg=quant_cfg) | ||
|
||
return model | ||
|
||
|
||
def get_quant_config_with_scheme(scheme: str): | ||
"""Get quantization config.""" | ||
|
||
try: | ||
# use scheme definitions from AutoRound since we utilize the quantization functions now | ||
from auto_round.schemes import preset_name_to_scheme | ||
|
||
quant_cfg = preset_name_to_scheme(scheme) | ||
return quant_cfg | ||
except ImportError: | ||
return None | ||
|
||
|
||
def convert_model_with_mapping(model, mapping=None): | ||
"""Process mapping to quant config.""" | ||
# key is torch module, TODO: support more key format, like layer name. | ||
for key in mapping: | ||
# TODO: support more torch modules | ||
if key == nn.Linear: | ||
quant_cfg = get_quant_config_with_scheme(mapping[key]) | ||
if quant_cfg is None: | ||
continue | ||
replace_with_quant_linear(model, quant_cfg) | ||
|
||
replaced_modules = sum(isinstance(m, TensorQuantizer) for _, m in model.named_modules()) | ||
print(f"Inserted {replaced_modules} quantizers") | ||
|
||
|
||
def get_quant_config(scheme: str) -> dict[str, Any]: | ||
"""Generate quantization config for a torch model. | ||
|
||
Args: | ||
model: The PyTorch model to analyze | ||
|
||
Returns: | ||
Dictionary containing the quantization configuration | ||
""" | ||
|
||
# TODO: support more quant config | ||
try: | ||
from auto_round.export.export_to_llmcompressor.config import initialize_quantization | ||
|
||
quantization_config = initialize_quantization(scheme=scheme) | ||
quantization_config = quantization_config.to_dict() | ||
quantization_config["provider"] = "auto-round" | ||
quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True | ||
quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True | ||
|
||
except ImportError: | ||
quantization_config = None | ||
|
||
return quantization_config | ||
|
||
|
||
def get_quantization_format(module) -> str | None: | ||
"""Gets the quantization string. | ||
|
||
Gets the quantization string by iterating through the module and its children. | ||
The first non-None quantization string is returned. | ||
""" | ||
|
||
def _get_quantization_from_layer(layer): | ||
weight_quantizer = getattr(layer, "weight_quantizer", None) | ||
input_quantizer = getattr(layer, "input_quantizer", None) | ||
|
||
if weight_quantizer is None or weight_quantizer._disabled: | ||
return None | ||
|
||
# TODO: support more quant format | ||
if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8": | ||
return "MXFP8" | ||
|
||
# Raise error for unsupported num_bits | ||
raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}") | ||
|
||
quantization = _get_quantization_from_layer(module) | ||
if quantization is not None: | ||
return quantization | ||
|
||
for _, layer in module.named_children(): | ||
format = get_quantization_format(layer) | ||
if format is not None: | ||
return format | ||
|
||
return None | ||
|
||
|
||
def is_quantlinear(module: nn.Module) -> bool: | ||
"""Returns whether the module is a quantized linear layer.""" | ||
return "QuantLinear" in type(module).__name__ |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.