Refactor MP config and kpi_data API (#963)
* Remove the old MP config and drop the "V2" suffix from the current MP config class name.
* Add a default CoreConfig value to the kpi_data API methods (both PyTorch and Keras).

---------

Co-authored-by: reuvenp <reuvenp@altair-semi.com>
reuvenperetz and reuvenp authored Mar 4, 2024
1 parent 40ee30d commit 04a9db1
Showing 51 changed files with 149 additions and 223 deletions.
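For callers, the practical effect of this refactor is a rename plus a deletion: the class previously exposed as MixedPrecisionQuantizationConfigV2 keeps its behavior but now carries the name MixedPrecisionQuantizationConfig, while the old QuantizationConfig-based wrapper of that name is removed. A minimal before/after sketch; the mct alias and the num_of_images argument follow the docstring examples further down this diff:

import model_compression_toolkit as mct

# Before this commit (class carried a V2 suffix):
#   mp_cfg = mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1)

# After this commit, the same class under the reused name:
mp_cfg = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
core_cfg = mct.core.CoreConfig(mixed_precision_config=mp_cfg)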
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/__init__.py
@@ -22,6 +22,6 @@
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, QuantizationErrorMethod, DEFAULTCONFIG
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig, MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data
model_compression_toolkit/core/common/framework_implementation.py
@@ -18,7 +18,7 @@
import numpy as np

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
-from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -304,7 +304,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
@abstractmethod
def get_sensitivity_evaluator(self,
graph: Graph,
-quant_config: MixedPrecisionQuantizationConfigV2,
+quant_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
fw_info: FrameworkInfo,
hessian_info_service: HessianInfoService = None,
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py
@@ -13,16 +13,13 @@
# limitations under the License.
# ==============================================================================

-from enum import Enum
-from typing import List, Callable, Tuple
+from typing import List, Callable

-from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, DEFAULTCONFIG
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse


-class MixedPrecisionQuantizationConfigV2:
+class MixedPrecisionQuantizationConfig:

def __init__(self,
compute_distance_fn: Callable = None,
@@ -36,8 +33,6 @@ def __init__(self,
metric_normalization_threshold: float = 1e10):
"""
Class with mixed precision parameters to quantize the input model.
-Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
-support mixed-precision model quantization.
Args:
compute_distance_fn (Callable): Function to compute a distance between two tensors.
@@ -70,67 +65,6 @@ def __init__(self,
self.metric_normalization_threshold = metric_normalization_threshold


-class MixedPrecisionQuantizationConfig(QuantizationConfig):
-
-    def __init__(self,
-                 qc: QuantizationConfig = DEFAULTCONFIG,
-                 compute_distance_fn: Callable = compute_mse,
-                 distance_weighting_method: Callable = get_average_weights,
-                 num_of_images: int = 32,
-                 configuration_overwrite: List[int] = None,
-                 num_interest_points_factor: float = 1.0):
-        """
-        Class to wrap all different parameters the library quantize the input model according to.
-        Unlike QuantizationConfig, number of bits for quantization is a list of possible bit widths to
-        support mixed-precision model quantization.
-        Args:
-            qc (QuantizationConfig): QuantizationConfig object containing parameters of how the model should be quantized.
-            compute_distance_fn (Callable): Function to compute a distance between two tensors.
-            distance_weighting_method (Callable): Function to use when weighting the distances among different layers when computing the sensitivity metric.
-            num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
-            configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
-            num_interest_points_factor: A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
-        """
-
-        super().__init__(**qc.__dict__)
-        self.compute_distance_fn = compute_distance_fn
-        self.distance_weighting_method = distance_weighting_method
-        self.num_of_images = num_of_images
-        self.configuration_overwrite = configuration_overwrite
-
-        assert 0.0 < num_interest_points_factor <= 1.0, "num_interest_points_factor should represent a percentage of " \
-                                                        "the base set of interest points that are required to be " \
-                                                        "used for mixed-precision metric evaluation, " \
-                                                        "thus, it should be between 0 to 1"
-        self.num_interest_points_factor = num_interest_points_factor
-
-    def separate_configs(self) -> Tuple[QuantizationConfig, MixedPrecisionQuantizationConfigV2]:
-        """
-        A function to separate the old MixedPrecisionQuantizationConfig into QuantizationConfig
-        and MixedPrecisionQuantizationConfigV2
-        Returns: QuantizationConfig, MixedPrecisionQuantizationConfigV2
-        """
-        _dummy_quant_config = QuantizationConfig()
-        _dummy_mp_config_experimental = MixedPrecisionQuantizationConfigV2()
-        qc_dict = {}
-        mp_dict = {}
-        for k, v in self.__dict__.items():
-            if hasattr(_dummy_quant_config, k):
-                qc_dict.update({k: v})
-            elif hasattr(_dummy_mp_config_experimental, k):
-                mp_dict.update({k: v})
-            else:
-                Logger.error(f'Attribute "{k}" mismatch: exists in MixedPrecisionQuantizationConfig but not in '
-                             f'MixedPrecisionQuantizationConfigV2')  # pragma: no cover
-
-        return QuantizationConfig(**qc_dict), MixedPrecisionQuantizationConfigV2(**mp_dict)
-
-
# Default quantization configuration the library use.
-DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(DEFAULTCONFIG,
-                                                                 compute_mse,
-                                                                 get_average_weights)
+DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
+                                                                 distance_weighting_method=get_average_weights)
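Every remaining constructor argument has a default, so a bare MixedPrecisionQuantizationConfig() is valid; the new DEFAULT_MIXEDPRECISION_CONFIG above relies on exactly that, overriding only two parameters. A sketch limited to names visible in this hunk:

from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
    MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse

# Equivalent to the DEFAULT_MIXEDPRECISION_CONFIG defined above.
mp_cfg = MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
                                          distance_weighting_method=get_average_weights)

# All arguments default, which is what the new kpi_data defaults below depend on.
mp_cfg_defaults = MixedPrecisionQuantizationConfig()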
(file path not shown in this extract)
@@ -18,7 +18,7 @@
import numpy as np
from typing import List, Callable, Dict

-from model_compression_toolkit.core import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
@@ -48,7 +48,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
target_kpi: KPI,
-mp_config: MixedPrecisionQuantizationConfigV2,
+mp_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
hessian_info_service: HessianInfoService=None) -> List[int]:
(file path not shown in this extract)
@@ -18,7 +18,7 @@
from typing import Callable, Any, List, Tuple

from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
-from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode

@@ -37,7 +37,7 @@ class SensitivityEvaluation:

def __init__(self,
graph: Graph,
-quant_config: MixedPrecisionQuantizationConfigV2,
+quant_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
fw_info: FrameworkInfo,
fw_impl: Any,
model_compression_toolkit/core/common/quantization/core_config.py
@@ -14,7 +14,7 @@
# ==============================================================================
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig


class CoreConfig:
@@ -23,14 +23,14 @@ class CoreConfig:
"""
def __init__(self,
quantization_config: QuantizationConfig = QuantizationConfig(),
-mixed_precision_config: MixedPrecisionQuantizationConfigV2 = None,
+mixed_precision_config: MixedPrecisionQuantizationConfig = None,
debug_config: DebugConfig = DebugConfig()
):
"""
Args:
quantization_config (QuantizationConfig): Config for quantization.
-mixed_precision_config (MixedPrecisionQuantizationConfigV2): Config for mixed precision quantization (optional, default=None).
+mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization (optional, default=None).
debug_config (DebugConfig): Config for debugging and editing the network quantization process.
"""
self.quantization_config = quantization_config
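CoreConfig itself keeps None as the mixed-precision default, so single-precision flows are untouched; only the annotation and docstring change. A short sketch using the mct.core aliases re-exported in core/__init__.py above:

import model_compression_toolkit as mct

# Single-precision flow: mixed_precision_config stays None by default.
cfg = mct.core.CoreConfig(quantization_config=mct.core.QuantizationConfig())

# Mixed-precision flow: pass the renamed config explicitly.
mp_cfg = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig())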
4 changes: 2 additions & 2 deletions model_compression_toolkit/core/keras/keras_implementation.py
@@ -57,7 +57,7 @@
Concatenate, Add
from keras.layers.core import TFOpLambda

-from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -355,7 +355,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz

def get_sensitivity_evaluator(self,
graph: Graph,
-quant_config: MixedPrecisionQuantizationConfigV2,
+quant_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
fw_info: FrameworkInfo,
disable_activation_for_metric: bool = False,
11 changes: 5 additions & 6 deletions model_compression_toolkit/core/keras/kpi_data_facade.py
@@ -14,8 +14,7 @@
# ==============================================================================

from typing import Callable
-
-from model_compression_toolkit.core import CoreConfig, MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import TENSORFLOW
@@ -36,7 +35,7 @@

def keras_kpi_data(in_model: Model,
representative_data_gen: Callable,
-core_config: CoreConfig,
+core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
"""
@@ -73,9 +72,9 @@ def keras_kpi_data(in_model: Model,
"""

-if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-    Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfigV2 object."
-                 "Given quant_config is not of type MixedPrecisionQuantizationConfigV2.")
+if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
+    Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
+                 "Given quant_config is not of type MixedPrecisionQuantizationConfig.")

fw_impl = KerasImplementation()

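With the new default argument, keras_kpi_data can be called with just a model and a representative dataset; the implicit CoreConfig carries a MixedPrecisionQuantizationConfig, so the isinstance guard above passes. An illustrative sketch: the MobileNetV2 model and random-batch generator are placeholders, not part of this diff:

import numpy as np
import model_compression_toolkit as mct
from tensorflow.keras.applications import MobileNetV2

def representative_data_gen():
    # Placeholder: yield one random batch matching the model's input shape.
    yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

model = MobileNetV2()
kpi_data = mct.core.keras_kpi_data(model, representative_data_gen)  # core_config now optional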
10 changes: 5 additions & 5 deletions model_compression_toolkit/core/pytorch/kpi_data_facade.py
@@ -22,7 +22,7 @@
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.constants import FOUND_TORCH

if FOUND_TORCH:
@@ -38,7 +38,7 @@

def pytorch_kpi_data(in_model: Module,
representative_data_gen: Callable,
-core_config: CoreConfig = CoreConfig(), # TODO: Why pytorch is initilized and keras not?
+core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
"""
@@ -75,9 +75,9 @@ def pytorch_kpi_data(in_model: Module,
"""

-if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
-    Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfigV2 object."
-                 "Given quant_config is not of type MixedPrecisionQuantizationConfigV2.")
+if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
+    Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
+                 "Given quant_config is not of type MixedPrecisionQuantizationConfig.")

fw_impl = PytorchImplementation()

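The PyTorch facade now gets the identical default, which also settles the removed TODO: both frameworks initialize core_config the same way. Mirroring the Keras sketch (the torchvision model and generator are placeholders):

import numpy as np
import model_compression_toolkit as mct
from torchvision.models import mobilenet_v2

def representative_data_gen():
    # Placeholder: yield one random batch in NCHW layout.
    yield [np.random.randn(1, 3, 224, 224).astype(np.float32)]

model = mobilenet_v2()
kpi_data = mct.core.pytorch_kpi_data(model, representative_data_gen)  # core_config now optional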
model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -26,7 +26,7 @@

import model_compression_toolkit.core.pytorch.constants as pytorch_constants
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
-from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -332,7 +332,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz

def get_sensitivity_evaluator(self,
graph: Graph,
-quant_config: MixedPrecisionQuantizationConfigV2,
+quant_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
fw_info: FrameworkInfo,
disable_activation_for_metric: bool = False,
8 changes: 4 additions & 4 deletions model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -24,7 +24,7 @@
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.runner import core_runner
from model_compression_toolkit.gptq.runner import gptq_runner
@@ -177,7 +177,7 @@ def keras_gradient_post_training_quantization(in_model: Model,
with different bitwidths for different layers.
The candidates bitwidth for quantization should be defined in the target platform model:
->>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfigV2(num_of_images=1))
+>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
For mixed-precision set a target KPI object:
Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -199,9 +199,9 @@
fw_info=fw_info).validate()

if core_config.mixed_precision_enable:
-if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.error("Given quantization config to mixed-precision facade is not of type "
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
"API, or pass a valid mixed precision configuration.") # pragma: no cover

tb_w = init_tensorboard_writer(fw_info)
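After the rename, the guard's contract reads: if mixed precision is enabled, core_config.mixed_precision_config must be an instance of MixedPrecisionQuantizationConfig, otherwise Logger.error fires. The docstring's own example above is the valid path:

import model_compression_toolkit as mct

# From the updated docstring: enable mixed precision via the renamed class.
config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

The same construction satisfies the equivalent check in the PyTorch GPTQ facade below.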
6 changes: 3 additions & 3 deletions model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -29,7 +29,7 @@
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    MixedPrecisionQuantizationConfigV2
+    MixedPrecisionQuantizationConfig

LR_DEFAULT = 1e-4
LR_REST_DEFAULT = 1e-4
@@ -157,9 +157,9 @@ def pytorch_gradient_post_training_quantization(model: Module,
"""

if core_config.mixed_precision_enable:
-if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
+if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.error("Given quantization config to mixed-precision facade is not of type "
"MixedPrecisionQuantizationConfigV2. Please use keras_post_training_quantization "
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
"API, or pass a valid mixed precision configuration.") # pragma: no cover

tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
Expand Down
(remaining changed files not loaded in this extract)
