diff --git a/.github/workflows/run_pytorch_tests.yml b/.github/workflows/run_pytorch_tests.yml index fa6a7d698..9d2913161 100644 --- a/.github/workflows/run_pytorch_tests.yml +++ b/.github/workflows/run_pytorch_tests.yml @@ -24,10 +24,11 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime + pip install pytest - name: Run unittests run: | python -m unittest discover tests/pytorch_tests -v - + pytest tests_pytest/pytorch diff --git a/.github/workflows/run_tests_suite_coverage.yml b/.github/workflows/run_tests_suite_coverage.yml index 23db7e109..42137999b 100644 --- a/.github/workflows/run_tests_suite_coverage.yml +++ b/.github/workflows/run_tests_suite_coverage.yml @@ -24,6 +24,7 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install coverage + pip install pytest - name: Prepare TF env run: pip install tensorflow==2.13.* - name: Run tensorflow testsuite @@ -32,6 +33,8 @@ jobs: run: pip uninstall tensorflow -y && pip install torch==2.0.* torchvision onnx onnxruntime onnxruntime-extensions - name: Run torch testsuite run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v + - name: Run torch pytest + run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/pytorch - name: Combine Multiple Coverage Files run: coverage combine - name: Run Coverage HTML diff --git a/model_compression_toolkit/gptq/__init__.py b/model_compression_toolkit/gptq/__init__.py index 899709817..a95bc9f91 100644 --- a/model_compression_toolkit/gptq/__init__.py +++ b/model_compression_toolkit/gptq/__init__.py @@ -13,8 +13,20 @@ # limitations under the License. 
# ============================================================================== -from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GPTQHessianScoresConfig -from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization -from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config -from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization -from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config \ No newline at end of file +from model_compression_toolkit.gptq.common.gptq_config import ( + GradientPTQConfig, + RoundingType, + GPTQHessianScoresConfig, + GradualActivationQuantizationConfig, + QFractionLinearAnnealingConfig +) + +from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH + +if FOUND_TF: + from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization + from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config + +if FOUND_TORCH: + from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization + from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config \ No newline at end of file diff --git a/model_compression_toolkit/gptq/common/gptq_config.py b/model_compression_toolkit/gptq/common/gptq_config.py index b15eb8eb1..dcd806a93 100644 --- a/model_compression_toolkit/gptq/common/gptq_config.py +++ b/model_compression_toolkit/gptq/common/gptq_config.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import dataclass, field from enum import Enum -from typing import Callable, Any, Dict +from typing import Callable, Any, Dict, Optional from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT @@ -32,91 +33,103 @@ class RoundingType(Enum): SoftQuantizer = 1 +@dataclass class GPTQHessianScoresConfig: """ Configuration to use for computing the Hessian-based scores for GPTQ loss metric. + + Args: + hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores. + norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1). + log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores. + scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores. + hessian_batch_size (int): The Hessian computation batch size. Used only when running GPTQ with a Hessian-based objective. """ + hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES + norm_scores: bool = True + log_norm: bool = True + scale_log_norm: bool = False + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE - def __init__(self, - hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES, - norm_scores: bool = True, - log_norm: bool = True, - scale_log_norm: bool = False, - hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE): - """ - Initialize a GPTQHessianWeightsConfig.
+@dataclass +class QFractionLinearAnnealingConfig: + """ + Config for the quantized fraction linear scheduler of Gradual Activation Quantization. - Args: - hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores. - norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1). - log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores. - scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores. - hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective. - """ + Args: + initial_q_fraction: Initial quantized fraction. + target_q_fraction: Target quantized fraction. + start_step: Gradient step at which to begin annealing. + end_step: Gradient step at which to complete annealing. None means the last step. + """ + initial_q_fraction: float + target_q_fraction: float + start_step: int + end_step: Optional[int] - self.hessians_num_samples = hessians_num_samples - self.norm_scores = norm_scores - self.log_norm = log_norm - self.scale_log_norm = scale_log_norm - self.hessian_batch_size = hessian_batch_size + def __post_init__(self): + if not (0 <= self.initial_q_fraction < self.target_q_fraction <= 1): + raise ValueError(f'Expected 0 <= initial_q_fraction < target_q_fraction <= 1, received initial_q_fraction ' + f'{self.initial_q_fraction} and target_q_fraction {self.target_q_fraction}.') + if self.start_step < 0: + raise ValueError(f'Expected start_step >= 0, received {self.start_step}.') + if self.end_step is not None and self.end_step <= self.start_step: + raise ValueError(f'Expected start_step < end_step, ' f'received end_step {self.end_step} and start_step {self.start_step}.') -class GradientPTQConfig: - """ - Configuration to use for quantization with GradientPTQ. - """ - def __init__(self, - n_epochs: int, - optimizer: Any, - optimizer_rest: Any = None, - loss: Callable = None, - log_function: Callable = None, - train_bias: bool = True, - rounding_type: RoundingType = RoundingType.SoftQuantizer, - use_hessian_based_weights: bool = True, - optimizer_quantization_parameter: Any = None, - optimizer_bias: Any = None, - regularization_factor: float = REG_DEFAULT, - hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(), - gptq_quantizer_params_override: Dict[str, Any] = None): - """ - Initialize a GradientPTQConfig. +@dataclass +class GradualActivationQuantizationConfig: + """ Configuration for Gradual Activation Quantization. + + By default, the quantized fraction increases linearly from 0 to 1 throughout the training. Args: - n_epochs (int): Number of representative dataset epochs to train. - optimizer (Any): Optimizer to use. - optimizer_rest (Any): Optimizer to use for bias and quantizer parameters. - loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, - the 3rd is a list of quantized weights, the 4th is a list of float weights, the 5th and 6th lists are the mean and std of the tensors - accordingly. see example in multiple_tensors_mse_loss - log_function (Callable): Function to log information about the GPTQ process. - train_bias (bool): Whether to update the bias during the training or not. - rounding_type (RoundingType): An enum that defines the rounding type. - use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
- optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters. - optimizer_bias (Any): Optimizer to override the rest optimizer for bias. - regularization_factor (float): A floating point number that defines the regularization factor. - hessian_weights_config (GPTQHessianScoresConfig): A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss. - gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters). - - """ - - self.n_epochs = n_epochs - self.optimizer = optimizer - self.optimizer_rest = optimizer_rest - self.loss = loss - self.log_function = log_function - self.train_bias = train_bias - - self.rounding_type = rounding_type - self.use_hessian_based_weights = use_hessian_based_weights - self.optimizer_quantization_parameter = optimizer_quantization_parameter - self.optimizer_bias = optimizer_bias - self.regularization_factor = regularization_factor - self.hessian_weights_config = hessian_weights_config - - self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \ - else gptq_quantizer_params_override + q_fraction_scheduler_policy: Config for the scheduling of the quantized fraction. + Only linear annealing is currently supported. + """ + q_fraction_scheduler_policy: QFractionLinearAnnealingConfig = field( + default_factory=lambda: QFractionLinearAnnealingConfig(initial_q_fraction=0, + target_q_fraction=1, + start_step=0, + end_step=None) + ) +@dataclass +class GradientPTQConfig: + """ + Configuration to use for quantization with GradientPTQ. + + Args: + n_epochs: Number of representative dataset epochs to train. + optimizer: Optimizer to use. + optimizer_rest: Optimizer to use for bias and quantizer parameters. + loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface. + log_function: Function to log information about the GPTQ process. + train_bias: Whether to update the bias during training. + rounding_type: An enum that defines the rounding type. + use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss. + optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters. + optimizer_bias: Optimizer to override the rest optimizer for bias. + regularization_factor: A floating point number that defines the regularization factor. + hessian_weights_config: A configuration that includes all necessary arguments to run a computation of + Hessian scores for the GPTQ loss. + gradual_activation_quantization_config: A configuration for Gradual Activation Quantization. + gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
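For orientation, a minimal usage sketch of how these dataclasses compose (illustrative only, not part of the patch; it assumes a standard torch optimizer and the mct.gptq exports added above):

import torch
from model_compression_toolkit.gptq import (GradientPTQConfig, GradualActivationQuantizationConfig,
                                            QFractionLinearAnnealingConfig)

# Ramp the quantized fraction from 10% to 100% over the first 1000 gradient steps.
annealing = QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=1.0,
                                           start_step=0, end_step=1000)
gptq_cfg = GradientPTQConfig(n_epochs=5,
                             optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4),
                             gradual_activation_quantization_config=GradualActivationQuantizationConfig(annealing))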
+ """ + n_epochs: int + optimizer: Any + optimizer_rest: Any = None + loss: Callable = None + log_function: Callable = None + train_bias: bool = True + rounding_type: RoundingType = RoundingType.SoftQuantizer + use_hessian_based_weights: bool = True + optimizer_quantization_parameter: Any = None + optimizer_bias: Any = None + regularization_factor: float = REG_DEFAULT + hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig) + gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None + gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict) diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index dcb08c0bf..d5a98ac93 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -21,6 +21,8 @@ import torch from model_compression_toolkit.core.common.hessian import HessianInfoService +from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \ + get_gradual_activation_quantizer_wrapper_factory from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq @@ -36,6 +38,7 @@ from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder +from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps class PytorchGPTQTrainer(GPTQTrainer): @@ -66,6 +69,13 @@ def __init__(self, representative_data_gen: Dataset to use for inputs of the models. hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model. """ + def _get_total_grad_steps(): + return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs + + # must be set prior to model building in the base class constructor + self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory( + gptq_config, _get_total_grad_steps) + super().__init__(graph_float, graph_quant, gptq_config, @@ -98,7 +108,7 @@ def __init__(self, self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights()) - self.reg_func = get_regularization(self.gptq_config, representative_data_gen) + self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps) def _is_gptq_weights_trainable(self, node: BaseNode) -> bool: @@ -145,7 +155,6 @@ def gptq_wrapper(self, def get_activation_quantizer_holder(self, n: BaseNode) -> Callable: """ Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node. - If the layer is not supposed to be wrapped with an activation quantizer - return None. Args: n: Node to attach a PytorchActivationQuantizationHolder to its output. 
Returns: @@ -153,13 +162,13 @@ def get_activation_quantizer_holder(self, n: BaseNode) -> Callable: """ _, activation_quantizers = quantization_builder(n, self.gptq_config) # Holder by definition uses a single quantizer for the activation quantization - # thus we make sure this is the only possible case (unless it's a node we no activation - # quantization, which in this case has an empty list). - if len(activation_quantizers) == 1: - return PytorchActivationQuantizationHolder(activation_quantizers[0]) - Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, " - f"but {len(activation_quantizers)} were found for node {n.name}. " - f"Ensure the node is configured with a single activation quantizer.") + # thus we make sure this is the only possible case + if len(activation_quantizers) != 1: + Logger.critical(f"'PytorchActivationQuantizationHolder' requires exactly one quantizer, " + f"but {len(activation_quantizers)} were found for node {n.name}. " + f"Ensure the node is configured with a single activation quantizer.") + quantizer = self.gradual_act_quantizer_wrapper_factory(activation_quantizers[0]) + return PytorchActivationQuantizationHolder(quantizer) def build_gptq_model(self): """ diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index d9b196740..96e74088c 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -13,26 +13,26 @@ # limitations under the License. # ============================================================================== import copy +from typing import Callable, Union -from typing import Callable -from model_compression_toolkit.core import common -from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE -from model_compression_toolkit.verify_packages import FOUND_TORCH +from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH +from model_compression_toolkit.core import CoreConfig +from model_compression_toolkit.core.analyzer import analyzer_model_quantization +from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ + MixedPrecisionQuantizationConfig +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ + ResourceUtilization from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer -from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT -from model_compression_toolkit.logger import Logger -from model_compression_toolkit.constants import PYTORCH -from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig -from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities -from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.runner import core_runner +from model_compression_toolkit.gptq.common.gptq_config import ( + GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig) +from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM from model_compression_toolkit.gptq.runner import gptq_runner -from 
model_compression_toolkit.core.analyzer import analyzer_model_quantization -from model_compression_toolkit.core import CoreConfig -from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ - MixedPrecisionQuantizationConfig -from model_compression_toolkit.metadata import get_versions_dict, create_model_metadata +from model_compression_toolkit.logger import Logger +from model_compression_toolkit.metadata import create_model_metadata +from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities +from model_compression_toolkit.verify_packages import FOUND_TORCH LR_DEFAULT = 1e-4 LR_REST_DEFAULT = 1e-4 @@ -53,33 +53,38 @@ DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL) def get_pytorch_gptq_config(n_epochs: int, - optimizer: Optimizer = Adam([torch.Tensor([])], lr=LR_DEFAULT), - optimizer_rest: Optimizer = Adam([torch.Tensor([])], lr=LR_REST_DEFAULT), + optimizer: Optimizer = None, + optimizer_rest: Optimizer = None, loss: Callable = multiple_tensors_mse_loss, log_function: Callable = None, use_hessian_based_weights: bool = True, regularization_factor: float = REG_DEFAULT, - hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE + hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE, + gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False, ) -> GradientPTQConfig: """ - Create a GradientPTQConfigV2 instance for Pytorch models. + Create a GradientPTQConfig instance for Pytorch models. args: n_epochs (int): Number of epochs for running the representative dataset for fine-tuning. optimizer (Optimizer): Pytorch optimizer to use for fine-tuning the auxiliary variables. optimizer_rest (Optimizer): Pytorch optimizer to use for fine-tuning of the bias variable. - loss (Callable): loss to use during fine-tuning. should accept 4 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors, the 3rd is a list of quantized weights and the 4th is a list of float weights. + loss (Callable): loss to use during fine-tuning. See the default loss function for the exact interface. log_function (Callable): Function to log information about the gptq process. use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss. regularization_factor (float): A floating point number that defines the regularization factor. hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ. + gradual_activation_quantization (bool, GradualActivationQuantizationConfig): + If False, GradualActivationQuantization is disabled. + If True, GradualActivationQuantization is enabled with the default settings. + A GradualActivationQuantizationConfig object can be passed to use non-default settings. returns: - a GradientPTQConfigV2 object to use when fine-tuning the quantized model using gptq. + a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
Examples: - Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs: + Import MCT and create a GradientPTQConfig to run for 5 epochs: >>> import model_compression_toolkit as mct >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5) @@ -89,16 +94,31 @@ def get_pytorch_gptq_config(n_epochs: int, >>> import torch >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)])) - The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq. + To enable Gradual Activation Quantization with non-default settings, build a GradualActivationQuantizationConfig: + >>> gradual_act_conf = mct.gptq.GradualActivationQuantizationConfig(mct.gptq.QFractionLinearAnnealingConfig(initial_q_fraction=0.2, target_q_fraction=1, start_step=0, end_step=None)) + >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, gradual_activation_quantization=gradual_act_conf) + The configuration can be passed to :func:`~model_compression_toolkit.pytorch_gradient_post_training_quantization` in order to quantize a pytorch model using gptq. """ + optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT) + optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT) + bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM) + + if isinstance(gradual_activation_quantization, bool): + gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None + elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig): + gradual_quant_config = gradual_activation_quantization + else: + raise TypeError(f'gradual_activation_quantization argument should be bool or ' + f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') # pragma: no cover + return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss, log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights, regularization_factor=regularization_factor, - hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)) - + hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size), + gradual_activation_quantization_config=gradual_quant_config) def pytorch_gradient_post_training_quantization(model: Module, representative_data_gen: Callable, @@ -222,11 +242,11 @@ def pytorch_gradient_post_training_quantization(model: Module, else: # If torch is not installed, # we raise an exception when trying to use these functions. - def get_pytorch_gptq_config(*args, **kwargs): + def get_pytorch_gptq_config(*args, **kwargs): # pragma: no cover Logger.critical("PyTorch must be installed to use 'get_pytorch_gptq_config'. " - "The 'torch' package is missing.") # pragma: no cover + "The 'torch' package is missing.") - def pytorch_gradient_post_training_quantization(*args, **kwargs): + def pytorch_gradient_post_training_quantization(*args, **kwargs): # pragma: no cover Logger.critical("PyTorch must be installed to use 'pytorch_gradient_post_training_quantization'. 
" - "The 'torch' package is missing.") # pragma: no cover + "The 'torch' package is missing.") diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py b/model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py new file mode 100644 index 000000000..19231e2b5 --- /dev/null +++ b/model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py @@ -0,0 +1,80 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import partial +from typing import Callable + +from model_compression_toolkit.gptq import GradientPTQConfig, QFractionLinearAnnealingConfig +from model_compression_toolkit.trainable_infrastructure import BasePytorchTrainableQuantizer + +from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler + + +def get_gradual_activation_quantizer_wrapper_factory(gptq_config: GradientPTQConfig, + get_total_grad_steps_fn: Callable[[], int]) \ + -> Callable[[BasePytorchTrainableQuantizer], 'GradualActivationQuantizerWrapper']: + """ + Get a factory for 'GradualActivationQuantizerWrapper'. + + Args: + gptq_config: GPTQ configuration. + get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps. + + Returns: + A factory function to build 'GradualActivationQuantizerWrapper' from Quantizer. + """ + if gptq_config.gradual_activation_quantization_config is None: + return lambda q: q + + annealing_cfg = gptq_config.gradual_activation_quantization_config.q_fraction_scheduler_policy + if isinstance(annealing_cfg, QFractionLinearAnnealingConfig): + t_end = annealing_cfg.end_step or get_total_grad_steps_fn() + factor_scheduler = LinearAnnealingScheduler(t_start=annealing_cfg.start_step, t_end=t_end, + initial_val=annealing_cfg.initial_q_fraction, + target_val=annealing_cfg.target_q_fraction) + else: + raise ValueError(f'Unknown annealing policy {annealing_cfg}') + + return partial(GradualActivationQuantizerWrapper, q_fraction_scheduler=factor_scheduler) + + +class GradualActivationQuantizerWrapper: + # TODO update paper's url + """ + Quantizer wrapper for Gradual Activation Quantization training (https://arxiv.org/abs/2309.11531). + + It computes the weighted sum of the float activation 'x' and the quantized activation 'q(x)': + + out = (1 - q_fraction) * x + q_fraction * q(x) + + where 'q_fraction' is a tensor fraction to quantize in the range [0, 1] provided by a scheduler. + + Args: + quantizer: quantizer to wrap. + q_fraction_scheduler: a callable that accepts a gradient step and returns the corresponding quantized fraction. 
+ """ + def __init__(self, quantizer: BasePytorchTrainableQuantizer, q_fraction_scheduler: Callable[[int], float]): + self.quantizer = quantizer + self.q_fraction_scheduler = q_fraction_scheduler + self.step_cnt = 0 + + def __call__(self, x, training: bool = True): + q_fraction = self.q_fraction_scheduler(self.step_cnt) + out_q = self.quantizer(x, training) + out = (1 - q_fraction) * x + q_fraction * out_q + self.step_cnt += 1 + return out + + def initialize_quantization(self, *args, **kwargs): + self.quantizer.initialize_quantization(*args, **kwargs) diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py b/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py index 4c40ff807..e4aef7932 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py @@ -12,33 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from tqdm import tqdm from typing import Callable -from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig +from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \ SoftQuantizerRegularization +from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler -def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable: +WARMUP_STEP_FRACTION = 0.2 + +def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn: Callable[[], int]) -> Callable: """ Returns a function that computes the regularization term for GPTQ training based on the given rounding type in the GPTQ configuration. Args: gptq_config: A GPTQ configuration. - representative_data_gen: Dataset used for the GPTQ training. + get_total_grad_steps_fn: a callable to obtain the total expected number of gradient steps. Returns: A function for computing the regularization. If there is no regularization function defined for the given rounding type, then it returns a function that just returns 0. """ if gptq_config.rounding_type == RoundingType.SoftQuantizer: - # dry run on the representative dataset to count number of batches - num_batches = 0 - for _ in tqdm(representative_data_gen(), "GPTQ initialization"): - num_batches += 1 - - return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs) + total_gradient_steps = get_total_grad_steps_fn() + t_start = int(WARMUP_STEP_FRACTION * total_gradient_steps) + scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2) + return SoftQuantizerRegularization(scheduler) else: return lambda m, e_reg: 0 diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py index 972fd1284..b08c54faf 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py @@ -12,57 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import List +from typing import List, Callable import torch -import numpy as np from torch import nn +from mct_quantizers import PytorchQuantizationWrapper from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.core.pytorch.utils import to_torch_tensor from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq -from mct_quantizers import PytorchQuantizationWrapper - - -class LinearTempDecay: - """ - Annealing process for the soft quantizer regularization temperature term. - """ - - def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2): - """ - Initializes a LinearTempDecay object. - - Args: - t_max: maximal time step. - rel_start_decay: Decay step size at the beginning of the process. - start_b: Starting value of the regularization term. - end_b: Target value of the regularization term. - """ - - self.t_max = t_max - self.start_decay = rel_start_decay * t_max - self.start_b = start_b - self.end_b = end_b - - def __call__(self, t: float) -> float: - """ - Cosine annealing scheduler for soft quantizer regularization temperature term. - - Args: - t: The current time step. - - Returns: Scheduled temperature. - """ - - is_before_start_decay = (t < self.start_decay) - - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - - return self.start_b * is_before_start_decay + \ - (1 - is_before_start_decay) * \ - (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])), - to_torch_tensor(np.array((1 - rel_t))))) class SoftQuantizerRegularization: @@ -70,16 +27,16 @@ class SoftQuantizerRegularization: A class to handle the computation of soft quantizer regularization for GPTQ training. """ - def __init__(self, total_gradient_steps: int): + def __init__(self, beta_scheduler: Callable[[int], float]): """ Initializes the regularization computation object with a beta scheduler. Args: - total_gradient_steps: The number of gradient steps during optimization. + beta_scheduler: a callable that accepts the current time step and returns the corresponding beta value. """ # Scheduler for the regularization temperature (beta) over the expected gradient steps - self.linear_decay = LinearTempDecay(total_gradient_steps) + self.beta_scheduler = beta_scheduler self.count_iter = 0 @@ -95,7 +52,7 @@ def __call__(self, model: nn.Module, entropy_reg: float): """ soft_reg_aux: List[torch.Tensor] = [] - b = self.linear_decay(self.count_iter) + b = self.beta_scheduler(self.count_iter) for layer in model.modules(): if isinstance(layer, PytorchQuantizationWrapper): kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer), diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py b/model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py new file mode 100644 index 000000000..d75bf9f7e --- /dev/null +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py @@ -0,0 +1,39 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor + + +class LinearAnnealingScheduler: + def __init__(self, t_start: int, t_end: int, initial_val: float, target_val: float): + """ + Linear annealing scheduler. Returns the corresponding annealed value per time step. + + Args: + t_start: time step to begin annealing. + t_end: time step to complete annealing. + initial_val: initial value. + target_val: target value. + """ + if not (0 <= t_start < t_end): + raise ValueError(f'Expected 0 <= t_start < t_end, actual {t_end=} {t_start=}') + + self.t_start = t_start + self.t_end = t_end + self.initial_val = initial_val + self.target_val = target_val + + def __call__(self, t: int) -> float: + factor = to_torch_tensor((t - self.t_start) / (self.t_end - self.t_start)).clip(0, 1) + return self.initial_val + factor * (self.target_val - self.initial_val) diff --git a/model_compression_toolkit/trainable_infrastructure/pytorch/util.py b/model_compression_toolkit/trainable_infrastructure/pytorch/util.py new file mode 100644 index 000000000..fec5062e5 --- /dev/null +++ b/model_compression_toolkit/trainable_infrastructure/pytorch/util.py @@ -0,0 +1,29 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from functools import cache +from typing import Callable + +from tqdm import tqdm + + +@cache +def get_total_grad_steps(representative_data_gen: Callable) -> int: + # dry run on the representative dataset to count number of batches + num_batches = 0 + for _ in tqdm(representative_data_gen(), "Estimating representative dataset size"): + num_batches += 1 + return num_batches + + diff --git a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py index 53033e18d..685b7f3da 100644 --- a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py +++ b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py @@ -9,14 +9,19 @@ from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO +from model_compression_toolkit.gptq import GradualActivationQuantizationConfig, QFractionLinearAnnealingConfig from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer +from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \ + GradualActivationQuantizerWrapper from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc from model_compression_toolkit.trainable_infrastructure import TrainingMethod from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import \ STESymmetricActivationTrainableQuantizer +from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters +from tests.pytorch_tests.utils import get_layers_from_model_by_type INPUT_SHAPE = [3, 8, 8] @@ -67,12 +72,7 @@ class TestGPTQModelBuilderWithActivationHolder(unittest.TestCase): def test_adding_holder_instead_quantize_wrapper(self): gptq_model = self._get_gptq_model(INPUT_SHAPE, BasicModel()) - last_module = list(gptq_model.named_modules())[-1][1] - activation_quantization_holders_in_model = [m[1] for m in gptq_model.named_modules() if isinstance(m[1], PytorchActivationQuantizationHolder)] - # the last module should be an activation quantization holder - self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) - # check that 4 activation quantization holders where generated - self.assertTrue(len(activation_quantization_holders_in_model) == 3) + activation_quantization_holders_in_model = self._get_holders_with_validation(gptq_model, exp_n_holders=3) for a in activation_quantization_holders_in_model: self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) self.assertEquals(a.activation_holder_quantizer.identifier, TrainingMethod.STE) @@ -86,12 +86,7 @@ def test_adding_holder_instead_quantize_wrapper(self): def test_adding_holder_after_relu(self): gptq_model = self._get_gptq_model(INPUT_SHAPE, ReLUModel()) - last_module = 
list(gptq_model.named_modules())[-1][1] - activation_quantization_holders_in_model = [m[1] for m in gptq_model.named_modules() if isinstance(m[1], PytorchActivationQuantizationHolder)] - # the last module should be an activation quantization holder - self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) - # check that 3 activation quantization holders where generated - self.assertTrue(len(activation_quantization_holders_in_model) == 3) + activation_quantization_holders_in_model = self._get_holders_with_validation(gptq_model, exp_n_holders=3) for a in activation_quantization_holders_in_model: self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) for name, module in gptq_model.named_modules(): @@ -101,12 +96,7 @@ def test_adding_holder_after_relu(self): def test_adding_holders_after_reuse(self): float_model = ReuseModel() gptq_model = self._get_gptq_model(INPUT_SHAPE, float_model) - activation_quantization_holders_in_model = [m[1] for m in gptq_model.named_modules() if isinstance(m[1], PytorchActivationQuantizationHolder)] - last_module = list(gptq_model.named_modules())[-1][1] - # the last module should be an activation quantization holder - self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) - # check that 4 activation quantization holders where generated - self.assertTrue(len(activation_quantization_holders_in_model) == 3) + activation_quantization_holders_in_model = self._get_holders_with_validation(gptq_model, exp_n_holders=3) for a in activation_quantization_holders_in_model: self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer)) for name, module in gptq_model.named_modules(): @@ -121,7 +111,39 @@ def test_adding_holders_after_reuse(self): # self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2]) # self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5]) - def _get_gptq_model(self, input_shape, in_model): + def test_adding_holder_with_gradual_act_quantization(self): + gradual_act_quant_cfg = GradualActivationQuantizationConfig( + QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=0.9, start_step=100, end_step=500) + ) + gptq_cfg = mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False, + gradual_activation_quantization=gradual_act_quant_cfg) + gptq_model = self._get_gptq_model(INPUT_SHAPE, BasicModel(), gptq_cfg) + activation_holders = self._get_holders_with_validation(gptq_model, exp_n_holders=3) + + for a in activation_holders: + self.assertTrue(isinstance(a.activation_holder_quantizer, GradualActivationQuantizerWrapper)) + # check that quantizer wrapper's scheduler was created according to gptq config + factor_scheduler = a.activation_holder_quantizer.q_fraction_scheduler + self.assertTrue(isinstance(factor_scheduler, LinearAnnealingScheduler)) + self.assertEqual(factor_scheduler.t_start, 100) + self.assertEqual(factor_scheduler.t_end, 500) + self.assertEqual(factor_scheduler.initial_val, 0.1) + self.assertEqual(factor_scheduler.target_val, 0.9) + # check the wrapped quantizer is correct and frozen + quantizer = a.activation_holder_quantizer.quantizer + self.assertTrue(isinstance(quantizer, STESymmetricActivationTrainableQuantizer)) + self.assertTrue(quantizer.freeze_quant_params is True) + self.assertEquals(quantizer.get_trainable_variables(VariableGroup.QPARAMS), []) + + def _get_holders_with_validation(self, gptq_model, 
exp_n_holders): + last_module = list(gptq_model.named_modules())[-1][1] + activation_quantization_holders = get_layers_from_model_by_type(gptq_model, PytorchActivationQuantizationHolder) + # the last module should be an activation quantization holder + self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) + self.assertTrue(len(activation_quantization_holders) == exp_n_holders) + return activation_quantization_holders + + def _get_gptq_model(self, input_shape, in_model, gptq_cfg=None): pytorch_impl = GPTQPytorchImplemantation() qc = copy.deepcopy(mct.core.DEFAULTCONFIG) qc.linear_collapsing = False @@ -135,11 +157,12 @@ def _get_gptq_model(self, input_shape, in_model): qc=qc) graph = set_bit_widths(mixed_precision_enable=False, graph=graph) + gptq_cfg = gptq_cfg or mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False) trainer = PytorchGPTQTrainer(graph, - graph, - mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False), - pytorch_impl, - DEFAULT_PYTORCH_INFO, - representative_dataset) + graph, + gptq_cfg, + pytorch_impl, + DEFAULT_PYTORCH_INFO, + representative_dataset) gptq_model, _ = trainer.build_gptq_model() - return gptq_model \ No newline at end of file + return gptq_model diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py index 7561d7afa..082aaac26 100644 --- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py @@ -26,7 +26,7 @@ from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest import model_compression_toolkit as mct from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GradientPTQConfig, RoundingType, \ - GPTQHessianScoresConfig + GPTQHessianScoresConfig, GradualActivationQuantizationConfig from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc @@ -57,7 +57,7 @@ class GPTQBaseTest(BasePytorchFeatureNetworkTest): def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, - num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES): + num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False): super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter) self.seed = 0 self.rounding_type = rounding_type @@ -70,6 +70,7 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \ rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict(default_value=1)} \ if rounding_type == RoundingType.STE else None + self.gradual_activation_quantization = gradual_activation_quantization def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING, @@ -113,11 +114,13 @@ def run_test(self): # Compare self.gptq_compare(ptq_model, gptq_model, input_x=x) + return gptq_model class GPTQAccuracyTest(GPTQBaseTest): def get_gptq_config(self): + gradual_act_cfg = 
GradualActivationQuantizationConfig() if self.gradual_activation_quantization else None return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=1e-4), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, @@ -125,7 +128,8 @@ def get_gptq_config(self): optimizer_bias=torch.optim.Adam([torch.Tensor([])], lr=0.4), hessian_weights_config=GPTQHessianScoresConfig(log_norm=self.log_norm_weights, scale_log_norm=self.scaled_log_norm), - gptq_quantizer_params_override=self.override_params) + gptq_quantizer_params_override=self.override_params, + gradual_activation_quantization_config=gradual_act_cfg) def gptq_compare(self, ptq_model, gptq_model, input_x=None): ptq_weights = torch_tensor_to_numpy(list(ptq_model.parameters())) @@ -137,9 +141,11 @@ def gptq_compare(self, ptq_model, gptq_model, input_x=None): class GPTQWeightsUpdateTest(GPTQBaseTest): def get_gptq_config(self): + gradual_act_cfg = GradualActivationQuantizationConfig() if self.gradual_activation_quantization else None return GradientPTQConfig(50, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0.5), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0.5), loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + gradual_activation_quantization_config=gradual_act_cfg, gptq_quantizer_params_override=self.override_params) def compare(self, ptq_model, gptq_model, input_x=None, max_change=None): @@ -158,9 +164,11 @@ def compare(self, ptq_model, gptq_model, input_x=None, max_change=None): class GPTQLearnRateZeroTest(GPTQBaseTest): def get_gptq_config(self): + gradual_act_cfg = GradualActivationQuantizationConfig() if self.gradual_activation_quantization else None return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=0), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=0), loss=multiple_tensors_mse_loss, train_bias=False, rounding_type=self.rounding_type, + gradual_activation_quantization_config=gradual_act_cfg, gptq_quantizer_params_override=self.override_params) def gptq_compare(self, ptq_model, gptq_model, input_x=None): diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 804958ee5..44b7d99d0 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -605,7 +605,6 @@ def test_gptq(self): per_channel=True, hessian_weights=True, log_norm_weights=True, scaled_log_norm=True).run_test() GPTQWeightsUpdateTest(self, rounding_type=RoundingType.SoftQuantizer).run_test() GPTQLearnRateZeroTest(self, rounding_type=RoundingType.SoftQuantizer).run_test() - GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, weights_quant_method=QuantizationMethod.UNIFORM).run_test() GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, @@ -618,6 +617,16 @@ def test_gptq(self): weights_quant_method=QuantizationMethod.UNIFORM, params_learning=False).run_test() # TODO: When params learning is True, the uniform quantizer gets a min value > max value + def test_gptq_with_gradual_activation(self): + """ + This test checks the GPTQ feature with gradual activation quantization. 
+ """ + GPTQAccuracyTest(self, gradual_activation_quantization=True).run_test() + GPTQAccuracyTest(self, rounding_type=RoundingType.SoftQuantizer, + gradual_activation_quantization=True).run_test() + GPTQLearnRateZeroTest(self, rounding_type=RoundingType.SoftQuantizer, + gradual_activation_quantization=True).run_test() + def test_qat(self): """ This test checks the QAT feature. diff --git a/tests_pytest/__init__.py b/tests_pytest/__init__.py new file mode 100644 index 000000000..e11a7cc60 --- /dev/null +++ b/tests_pytest/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/pytorch/__init__.py b/tests_pytest/pytorch/__init__.py new file mode 100644 index 000000000..e11a7cc60 --- /dev/null +++ b/tests_pytest/pytorch/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/pytorch/gptq/__init__.py b/tests_pytest/pytorch/gptq/__init__.py new file mode 100644 index 000000000..e11a7cc60 --- /dev/null +++ b/tests_pytest/pytorch/gptq/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/pytorch/gptq/test_annealing_cfg.py b/tests_pytest/pytorch/gptq/test_annealing_cfg.py new file mode 100644 index 000000000..10ed4a61f --- /dev/null +++ b/tests_pytest/pytorch/gptq/test_annealing_cfg.py @@ -0,0 +1,40 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import pytest + +from model_compression_toolkit.gptq import QFractionLinearAnnealingConfig + + +def test_linear_annealing_cfg_validation(): + with pytest.raises(ValueError, match='Expected.* target_q_fraction <= 1'): + QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=1.1, start_step=0, end_step=None) + + with pytest.raises(ValueError, match='Expected.* 0 <= initial_q_fraction'): + QFractionLinearAnnealingConfig(initial_q_fraction=-0.1, target_q_fraction=-0.9, start_step=0, end_step=100) + + with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'): + QFractionLinearAnnealingConfig(initial_q_fraction=0.1, target_q_fraction=0.1, start_step=0, end_step=100) + + with pytest.raises(ValueError, match='Expected.* initial_q_fraction < target_q_fraction'): + QFractionLinearAnnealingConfig(initial_q_fraction=0.2, target_q_fraction=0.1, start_step=0, end_step=100) + + with pytest.raises(ValueError, match='Expected.* start_step >= 0'): + QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=-1, end_step=100) + + with pytest.raises(ValueError, match='Expected.* start_step < end_step'): + QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=100) + + with pytest.raises(ValueError, match='Expected.* start_step < end_step'): + QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=100, end_step=99) diff --git a/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py new file mode 100644 index 000000000..655c43faa --- /dev/null +++ b/tests_pytest/pytorch/gptq/test_gradual_act_quantization.py @@ -0,0 +1,100 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+from model_compression_toolkit.gptq import GradientPTQConfig, GradualActivationQuantizationConfig, QFractionLinearAnnealingConfig
+from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import (
+    GradualActivationQuantizerWrapper, get_gradual_activation_quantizer_wrapper_factory)
+
+
+@pytest.fixture
+def x():
+    return torch.randn((2, 5, 6, 7), generator=torch.Generator().manual_seed(42)).to(device=get_working_device())
+
+
+class Quantizer:
+    def __call__(self, x, training):
+        self.training = training
+        return 3 * x + 1
+
+
+class TestGradualActivationQuantization:
+
+    def test_gradual_act_quant_wrapper(self, x):
+        quantizer = Quantizer()
+        qw = GradualActivationQuantizerWrapper(quantizer, q_fraction_scheduler=lambda t: t / (t + 1))
+
+        y0, y1, y2 = [qw(x) for _ in range(3)]
+        assert torch.equal(y0, x)  # t=0: q_fraction=0, fully float
+        assert torch.allclose(y1, 0.5 * x + (1.5 * x + 0.5))  # t=1: q_fraction=0.5
+        assert torch.allclose(y2, x / 3 + (2 * x + 2 / 3))  # t=2: q_fraction=2/3
+        assert quantizer.training is True
+
+        _ = qw(x, False)
+        assert quantizer.training is False  # the training flag was propagated to the quantizer
+
+    def test_factory_no_qdrop(self):
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg=None, get_grad_steps_fn=None)
+        assert quantizer_wrapper is quantizer
+
+    @pytest.mark.parametrize('end_step', (20, None))
+    def test_factory_linear(self, x, end_step):
+        qdrop_cfg = GradualActivationQuantizationConfig(
+            QFractionLinearAnnealingConfig(initial_q_fraction=0.3, target_q_fraction=0.8, start_step=10, end_step=end_step)
+        )
+
+        def get_total_steps():
+            if end_step is None:
+                return 50
+            assert False  # should not be called when end_step is passed explicitly
+
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg, get_total_steps)
+
+        scheduler = quantizer_wrapper.q_fraction_scheduler
+        assert isinstance(scheduler, LinearAnnealingScheduler)
+        exp_end_step = 50 if end_step is None else end_step
+        assert scheduler.t_start == 10
+        assert scheduler.t_end == exp_end_step
+        assert scheduler.initial_val == 0.3
+        assert scheduler.target_val == 0.8
+
+        y = [quantizer_wrapper(x) for _ in range(exp_end_step + 1)]
+        assert torch.allclose(y[9], 0.7 * x + 0.3 * quantizer(x, True))
+        assert torch.allclose(y[10], 0.7 * x + 0.3 * quantizer(x, True))
+        assert torch.allclose(y[-1], 0.2 * x + 0.8 * quantizer(x, True))
+
+    def test_factory_linear_common_case(self, x):
+        # Validate the common case end to end: the first call returns the float input, the last call the fully quantized output.
+        qdrop_cfg = GradualActivationQuantizationConfig(
+            QFractionLinearAnnealingConfig(initial_q_fraction=0, target_q_fraction=1, start_step=0, end_step=None)
+        )
+        quantizer_wrapper, quantizer = self._run_factory_test(qdrop_cfg, lambda: 15)
+        y0, *_, y_last = [quantizer_wrapper(x) for _ in range(16)]
+        assert torch.equal(y0, x)
+        assert torch.allclose(y_last, quantizer(x, True))
+
+    def _run_factory_test(self, qdrop_cfg, get_grad_steps_fn):
+        # Mocks are used only to satisfy required arguments that this test does not exercise.
+        gptq_cfg = GradientPTQConfig(n_epochs=5, optimizer=Mock(), loss=Mock(),
+                                     gradual_activation_quantization_config=qdrop_cfg)
+        factory = get_gradual_activation_quantizer_wrapper_factory(gptq_cfg, get_grad_steps_fn)
+        quantizer = Quantizer()
+        quantizer_wrapper = factory(quantizer)
+        return quantizer_wrapper, quantizer
diff --git a/tests_pytest/pytorch/trainable_infrastructure/__init__.py b/tests_pytest/pytorch/trainable_infrastructure/__init__.py
new file mode 100644
index 000000000..e11a7cc60
--- /dev/null
+++ b/tests_pytest/pytorch/trainable_infrastructure/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py b/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py
new file mode 100644
index 000000000..d6edca605
--- /dev/null
+++ b/tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py
@@ -0,0 +1,49 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+import pytest
+
+from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import LinearAnnealingScheduler
+
+
+def test_linear_annealing():
+    scheduler = LinearAnnealingScheduler(t_start=10, t_end=35, initial_val=3.4, target_val=-1.6)
+    for t in [0, 9, 10]:
+        assert _isclose(scheduler(t), 3.4)  # before and at t_start: initial value
+
+    for t in [35, 36, 1000]:
+        assert _isclose(scheduler(t), -1.6)  # at and after t_end: target value
+
+    assert _isclose(scheduler(11), 3.2)
+    assert _isclose(scheduler(27), 0.)
+    assert _isclose(scheduler(34), -1.4)
+
+
+def test_linear_annealing_ascending():
+    scheduler = LinearAnnealingScheduler(t_start=0, t_end=5, initial_val=-0.5, target_val=1.5)
+    assert _isclose(scheduler(0), -0.5)
+    assert _isclose(scheduler(1), -0.1)
+    assert _isclose(scheduler(4), 1.1)
+    assert _isclose(scheduler(5), 1.5)
+
+
+@pytest.mark.parametrize('start', [5, -1])
+def test_invalid(start):
+    with pytest.raises(ValueError):
+        LinearAnnealingScheduler(t_start=start, t_end=4, initial_val=1, target_val=0)
+
+
+def _isclose(x, y):
+    return torch.isclose(x, torch.tensor(y))
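
Reviewer note: the two test modules above pin down the runtime semantics of the new pieces. As a reading aid, here is a minimal sketch of those semantics. Class names, constructor signatures, and attribute names are taken from the tests themselves; the bodies are illustrative assumptions, not the MCT implementation.

import torch

# Illustrative sketch only. It reproduces the behavior the tests above assert,
# under the same constructor signatures and attribute names.

class LinearAnnealingSchedulerSketch:
    """Linearly anneals from initial_val to target_val between steps t_start and t_end."""

    def __init__(self, t_start: int, t_end: int, initial_val: float, target_val: float):
        if not 0 <= t_start < t_end:
            raise ValueError(f'Expected 0 <= t_start < t_end, got t_start={t_start}, t_end={t_end}.')
        self.t_start, self.t_end = t_start, t_end
        self.initial_val, self.target_val = initial_val, target_val

    def __call__(self, t: int) -> torch.Tensor:
        # Clamp progress to [0, 1] so the value is constant before t_start and after t_end.
        factor = (t - self.t_start) / (self.t_end - self.t_start)
        factor = min(max(factor, 0.), 1.)
        return torch.tensor(self.initial_val + factor * (self.target_val - self.initial_val))


class GradualActivationQuantizerWrapperSketch:
    """Mixes the float input with the quantized output: (1 - f(t)) * x + f(t) * q(x)."""

    def __init__(self, quantizer, q_fraction_scheduler):
        self.quantizer = quantizer
        self.q_fraction_scheduler = q_fraction_scheduler
        self.step = 0  # incremented on every call, serving as the annealing time t

    def __call__(self, x, training=True):
        q_fraction = self.q_fraction_scheduler(self.step)
        self.step += 1
        q_x = self.quantizer(x, training)  # the training flag is forwarded to the quantizer
        return (1 - q_fraction) * x + q_fraction * q_x

Under the common configuration (initial_q_fraction=0, target_q_fraction=1, start_step=0), such a wrapper feeds the float activations on the first gradient step and the fully quantized activations from end_step onward, which is exactly what test_factory_linear_common_case asserts.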