Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPTQ Gradual Activation Quantization #1210

Merged
merged 6 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime
pip install pytest
- name: Run unittests
run: |
python -m unittest discover tests/pytorch_tests -v
pytest tests_pytest/pytorch
22 changes: 17 additions & 5 deletions model_compression_toolkit/gptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,20 @@
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GPTQHessianScoresConfig
from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
from model_compression_toolkit.gptq.common.gptq_config import (
GradientPTQConfig,
RoundingType,
GPTQHessianScoresConfig,
GradualActivationQuantizationConfig,
LinearAnnealingConfig
)

from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH

if FOUND_TF:
from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config

if FOUND_TORCH:
from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
144 changes: 69 additions & 75 deletions model_compression_toolkit/gptq/common/gptq_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Any, Dict
from typing import Callable, Any, Dict, Optional

from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
Expand All @@ -32,91 +33,84 @@ class RoundingType(Enum):
SoftQuantizer = 1


@dataclass
class GPTQHessianScoresConfig:
"""
Configuration to use for computing the Hessian-based scores for GPTQ loss metric.

Args:
hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
"""
hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES
norm_scores: bool = True
log_norm: bool = True
scale_log_norm: bool = False
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE


def __init__(self,
hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES,
norm_scores: bool = True,
log_norm: bool = True,
scale_log_norm: bool = False,
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE):
@dataclass
class LinearAnnealingConfig:
"""
Config for Gradual Activation Quantization factor annealing. Factor refers to the weight of the float tensor.
ofirgo marked this conversation as resolved.
Show resolved Hide resolved
"""
initial_factor: float = 1
ofirgo marked this conversation as resolved.
Show resolved Hide resolved
target_factor: float = 0
start_step: int = 0 # gradient step to begin annealing
end_step: Optional[int] = None # gradient step to complete annealing. None means the last step.

"""
Initialize a GPTQHessianWeightsConfig.
def __post_init__(self):
if not (0 <= self.target_factor < self.initial_factor <= 1):
raise ValueError(f'Expected 0 <= target_factor < initial_factor <= 1, '
ofirgo marked this conversation as resolved.
Show resolved Hide resolved
f'received initial_factor {self.initial_factor} and target_factor {self.target_factor}')
if self.start_step < 0:
raise ValueError(f'Expected start_step >= 0. received {self.start_step}')
if self.end_step is not None and self.end_step <= self.start_step:
raise ValueError('Expected start_step < end_step, '
'received end_step {self.end_step} and start_step {self.start_stap}')

Args:
hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
"""

self.hessians_num_samples = hessians_num_samples
self.norm_scores = norm_scores
self.log_norm = log_norm
self.scale_log_norm = scale_log_norm
self.hessian_batch_size = hessian_batch_size
@dataclass
class GradualActivationQuantizationConfig:
ofirgo marked this conversation as resolved.
Show resolved Hide resolved
annealing_policy: LinearAnnealingConfig = field(default_factory=LinearAnnealingConfig)


@dataclass
class GradientPTQConfig:
"""
Configuration to use for quantization with GradientPTQ.
"""
def __init__(self,
n_epochs: int,
optimizer: Any,
optimizer_rest: Any = None,
loss: Callable = None,
log_function: Callable = None,
train_bias: bool = True,
rounding_type: RoundingType = RoundingType.SoftQuantizer,
use_hessian_based_weights: bool = True,
optimizer_quantization_parameter: Any = None,
optimizer_bias: Any = None,
regularization_factor: float = REG_DEFAULT,
hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(),
gptq_quantizer_params_override: Dict[str, Any] = None):
"""
Initialize a GradientPTQConfig.

Args:
n_epochs (int): Number of representative dataset epochs to train.
optimizer (Any): Optimizer to use.
optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
the 3rd is a list of quantized weights, the 4th is a list of float weights, the 5th and 6th lists are the mean and std of the tensors
accordingly. see example in multiple_tensors_mse_loss
log_function (Callable): Function to log information about the GPTQ process.
train_bias (bool): Whether to update the bias during the training or not.
rounding_type (RoundingType): An enum that defines the rounding type.
use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
regularization_factor (float): A floating point number that defines the regularization factor.
hessian_weights_config (GPTQHessianScoresConfig): A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss.
gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).

"""

self.n_epochs = n_epochs
self.optimizer = optimizer
self.optimizer_rest = optimizer_rest
self.loss = loss
self.log_function = log_function
self.train_bias = train_bias

self.rounding_type = rounding_type
self.use_hessian_based_weights = use_hessian_based_weights
self.optimizer_quantization_parameter = optimizer_quantization_parameter
self.optimizer_bias = optimizer_bias
self.regularization_factor = regularization_factor
self.hessian_weights_config = hessian_weights_config

self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
else gptq_quantizer_params_override


Args:
n_epochs: Number of representative dataset epochs to train.
optimizer: Optimizer to use.
optimizer_rest: Optimizer to use for bias and quantizer parameters.
loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
log_function: Function to log information about the GPTQ process.
train_bias: Whether to update the bias during the training or not.
rounding_type: An enum that defines the rounding type.
use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
optimizer_bias: Optimizer to override the rest optimizer for bias.
regularization_factor: A floating point number that defines the regularization factor.
hessian_weights_config: A configuration that include all necessary arguments to run a computation of
Hessian scores for the GPTQ loss.
gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
"""
n_epochs: int
optimizer: Any
optimizer_rest: Any = None
loss: Callable = None
log_function: Callable = None
train_bias: bool = True
rounding_type: RoundingType = RoundingType.SoftQuantizer
use_hessian_based_weights: bool = True
optimizer_quantization_parameter: Any = None
optimizer_bias: Any = None
regularization_factor: float = REG_DEFAULT
hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)
27 changes: 18 additions & 9 deletions model_compression_toolkit/gptq/pytorch/gptq_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import torch

from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
get_gradual_activation_quantizer_wrapper_factory
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
Expand All @@ -36,6 +38,7 @@
from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps


class PytorchGPTQTrainer(GPTQTrainer):
Expand Down Expand Up @@ -66,6 +69,13 @@ def __init__(self,
representative_data_gen: Dataset to use for inputs of the models.
hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
"""
def _get_total_grad_steps():
return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs

# must be set prior to model building in the base class constructor
self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
gptq_config, _get_total_grad_steps)

super().__init__(graph_float,
graph_quant,
gptq_config,
Expand Down Expand Up @@ -98,7 +108,7 @@ def __init__(self,

self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights())

self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)

def _is_gptq_weights_trainable(self,
node: BaseNode) -> bool:
Expand Down Expand Up @@ -145,21 +155,20 @@ def gptq_wrapper(self,
def get_activation_quantizer_holder(self, n: BaseNode) -> Callable:
"""
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
If the layer is not supposed to be wrapped with an activation quantizer - return None.
Args:
n: Node to attach a PytorchActivationQuantizationHolder to its output.
Returns:
A PytorchActivationQuantizationHolder module for the node's activation quantization.
"""
_, activation_quantizers = quantization_builder(n, self.gptq_config)
# Holder by definition uses a single quantizer for the activation quantization
# thus we make sure this is the only possible case (unless it's a node we no activation
# quantization, which in this case has an empty list).
if len(activation_quantizers) == 1:
return PytorchActivationQuantizationHolder(activation_quantizers[0])
Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
f"but {len(activation_quantizers)} were found for node {n.name}. "
f"Ensure the node is configured with a single activation quantizer.")
# thus we make sure this is the only possible case
if len(activation_quantizers) != 1:
Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, "
f"but {len(activation_quantizers)} were found for node {n.name}. "
f"Ensure the node is configured with a single activation quantizer.")
quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0])
return PytorchActivationQuantizationHolder(quantizer)

def build_gptq_model(self):
"""
Expand Down
Loading
Loading