diff --git a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
index 5b746eed1..762f63ca8 100644
--- a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
@@ -40,9 +40,9 @@ def set_bit_widths(mixed_precision_enable: bool,
                     for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
             "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
 
-        Logger.info(f'Set bit widths from configuration: {bit_widths_config}')
         # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
         sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)
+
         for node in graph.nodes:  # set a specific node qc for each node final qc
             # If it's reused, take the configuration that the base node has
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
diff --git a/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py b/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py
index 4c4f17bd8..9483d40f0 100644
--- a/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py
+++ b/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from enum import Enum
+from functools import partial
 
 import numpy as np
 
@@ -21,6 +23,7 @@ def get_average_weights(distance_matrix: np.ndarray) -> np.ndarray:
     Get weights for weighting the sensitivity among different layers when evaluating MP configurations
     on model's sensitivity. This function returns equal weights for each layer, such that the sensitivity
     is averaged over all layers.
+
     Args:
         distance_matrix: Numpy array at shape (L,M): L -number of interest points, M number of samples.
             The matrix contain the distance for each interest point at each sample.
@@ -50,3 +53,21 @@ def get_last_layer_weights(distance_matrix: np.ndarray) -> np.ndarray:
     w = np.asarray([0 for _ in range(num_nodes)])
     w[-1] = 1
     return w
+
+
+class MpDistanceWeighting(Enum):
+    """
+    Defines mixed precision distance metric weighting methods.
+    Each enum value wraps a weighting function, so a member can be called directly on a distance matrix.
+
+    AVG - take the average distance over all computed layers.
+
+    LAST_LAYER - take only the distance of the last layer's output.
+
+    """
+
+    AVG = partial(get_average_weights)
+    LAST_LAYER = partial(get_last_layer_weights)
+
+    def __call__(self, distance_matrix: np.ndarray) -> np.ndarray:
+        return self.value(distance_matrix)
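A minimal caller-side sketch of the new enum, with toy values (the exact weight vectors returned depend on `get_average_weights`/`get_last_layer_weights`):

```python
import numpy as np

from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting

# Toy distance matrix of shape (L, M): L=3 interest points, M=2 samples.
distance_matrix = np.array([[0.1, 0.2],
                            [0.3, 0.4],
                            [0.5, 0.6]])

# Members are callable: __call__ forwards to the wrapped weighting function.
avg_weights = MpDistanceWeighting.AVG(distance_matrix)          # equal weight per interest point
last_weights = MpDistanceWeighting.LAST_LAYER(distance_matrix)  # all weight on the last interest point
```

Wrapping the functions in `functools.partial` is what makes this work: a plain function assigned in an `Enum` body would be treated as a method of the enum rather than a member, while a `partial` becomes a proper member value that `__call__` can dispatch to.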
+ + """ + + AVG = partial(get_average_weights) + LAST_LAYER = partial(get_last_layer_weights) + + def __call__(self, distance_matrix: np.ndarray) -> np.ndarray: + return self.value(distance_matrix) diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py index 906c8b26e..60e9d5265 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py @@ -15,19 +15,18 @@ from typing import List, Callable -from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights -from model_compression_toolkit.core.common.similarity_analyzer import compute_mse +from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting class MixedPrecisionQuantizationConfig: def __init__(self, compute_distance_fn: Callable = None, - distance_weighting_method: Callable = get_average_weights, + distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG, num_of_images: int = 32, configuration_overwrite: List[int] = None, num_interest_points_factor: float = 1.0, - use_hessian_based_scores: bool = True, + use_hessian_based_scores: bool = False, norm_scores: bool = True, refine_mp_solution: bool = True, metric_normalization_threshold: float = 1e10): @@ -35,8 +34,8 @@ def __init__(self, Class with mixed precision parameters to quantize the input model. Args: - compute_distance_fn (Callable): Function to compute a distance between two tensors. - distance_weighting_method (Callable): Function to use when weighting the distances among different layers when computing the sensitivity metric. + compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer. + distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric. num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model. configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one. num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric. @@ -63,8 +62,3 @@ def __init__(self, self.norm_scores = norm_scores self.metric_normalization_threshold = metric_normalization_threshold - - -# Default quantization configuration the library use. 
diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
index 48352c152..7fde85628 100644
--- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
+++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
@@ -232,7 +232,8 @@ def _add_set_of_kpi_constraints(search_manager: MixedPrecisionSearchManager,
 
 
 def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
-                                    target_kpi: KPI) -> Dict[int, Dict[int, float]]:
+                                    target_kpi: KPI,
+                                    eps: float = EPS) -> Dict[int, Dict[int, float]]:
     """
     This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
     It builds a mapping from a node's index, to its bitwidht's effect on the model sensitivity.
@@ -245,6 +246,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
         search_manager: MixedPrecisionSearchManager object to be used for problem formalization.
         target_kpi: KPI to constrain our LP problem with some resources limitations
             (like model' weights memory consumption).
+        eps: Epsilon value to manually increase metric value (if necessary) for numerical stability.
 
     Returns:
         Mapping from each node's index in a graph, to a dictionary from the bitwidth index (of this node) to
@@ -287,16 +289,18 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
                     original_base_config=origin_max_config)
                 origin_changed_nodes_indices = [i for i, c in enumerate(origin_max_config) if
                                                 c != origin_mp_model_configuration[i]]
-                layer_to_metrics_mapping[node_idx][bitwidth_idx] = search_manager.compute_metric_fn(
+                metric_value = search_manager.compute_metric_fn(
                     origin_mp_model_configuration,
                     origin_changed_nodes_indices,
                     origin_max_config)
             else:
-                layer_to_metrics_mapping[node_idx][bitwidth_idx] = search_manager.compute_metric_fn(
+                metric_value = search_manager.compute_metric_fn(
                     mp_model_configuration,
                     [node_idx],
                     search_manager.max_kpi_config)
 
+            layer_to_metrics_mapping[node_idx][bitwidth_idx] = max(metric_value, max_config_value + eps)
+
     # Finalize distance metric mapping
     search_manager.finalize_distance_metric(layer_to_metrics_mapping)
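A small self-contained sketch of what the new clamping line does (the `EPS` value below is an assumed stand-in for the constant imported in `linear_programming.py`): every candidate's metric is forced to sit at least `eps` above the metric of the all-max-bit-width configuration, so a noisy sensitivity estimate can never make a lower bit-width look strictly better than the maximal one:

```python
EPS = 1e-8  # assumed stand-in; the real constant comes from MCT's constants module

def clamp_metric(metric_value: float, max_config_value: float, eps: float = EPS) -> float:
    # Mirror of the new assignment in _build_layer_to_metrics_mapping.
    return max(metric_value, max_config_value + eps)

# A noisy estimate that "beats" the maximal configuration is clamped...
assert clamp_metric(0.0, max_config_value=0.5) == 0.5 + EPS
# ...while a genuinely worse candidate keeps its measured value.
assert clamp_metric(0.9, max_config_value=0.5) == 0.9
```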
diff --git a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
index bab966ea3..aaed945f6 100644
--- a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
+++ b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py
@@ -101,7 +101,10 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
             new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
             changed = True
 
-    Logger.info(f'Greedy MP algorithm changed configuration from: {mp_solution} to {new_solution}')
+    if any([mp_solution[i] != new_solution[i] for i in range(len(mp_solution))]):
+        Logger.info(f'Greedy MP algorithm changed configuration from (numbers represent indices of the '
+                    f'chosen bit-width candidate for each layer):\n{mp_solution}\nto\n{new_solution}')
+
     return new_solution
 
diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py
index 7a8ad2110..22752e6a4 100644
--- a/model_compression_toolkit/core/runner.py
+++ b/model_compression_toolkit/core/runner.py
@@ -144,10 +144,13 @@ def core_runner(in_model: Any,
         weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
         activation_conf_nodes_bitwidth = tg.get_final_activation_config()
 
-        Logger.info(
-            f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')
-        Logger.info(
-            f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')
+        if len(weights_conf_nodes_bitwidth) > 0:
+            Logger.info(
+                f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')
+
+        if len(activation_conf_nodes_bitwidth) > 0:
+            Logger.info(
+                f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')
 
     if tb_w is not None:
         finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)
diff --git a/tests/common_tests/base_test.py b/tests/common_tests/base_test.py
index 3ed7b4517..5b08bc613 100644
--- a/tests/common_tests/base_test.py
+++ b/tests/common_tests/base_test.py
@@ -35,13 +35,13 @@ def get_input_shapes(self):
 
     def get_core_config(self):
        return CoreConfig(quantization_config=self.get_quantization_config(),
-                         mixed_precision_config=self.get_mixed_precision_v2_config(),
+                         mixed_precision_config=self.get_mixed_precision_config(),
                          debug_config=self.get_debug_config())
 
     def get_quantization_config(self):
         return QuantizationConfig()
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return None
 
     def get_debug_config(self):
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py
index 31bcfa3d7..010c52083 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_bops_test.py
@@ -46,7 +46,7 @@ def get_tpc(self):
                                      mp_bitwidth_candidates_list=self.mixed_precision_candidates_list,
                                      name="mp_bopts_test")
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return MixedPrecisionQuantizationConfig(num_of_images=1)
 
     def get_input_shapes(self):
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
index 71ce0c428..baccc1fdb 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/mixed_precision_tests.py
@@ -64,7 +64,7 @@ def get_quantization_config(self):
                                            input_scaling=False, activation_channel_equalization=False)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
 
     def get_input_shapes(self):
@@ -422,7 +422,7 @@ def get_quantization_config(self):
                                            relu_bound_to_power_of_2=False, weights_bias_correction=True,
                                            input_scaling=False, activation_channel_equalization=False)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=self.num_of_inputs)
 
     def create_networks(self):
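With the `v2` suffix dropped, a test opts into mixed precision simply by overriding the renamed hook. A hypothetical subclass (both the base-class name `BaseTest` and the subclass are assumptions for illustration; only the hook name comes from the diff):

```python
import model_compression_toolkit as mct
from tests.common_tests.base_test import BaseTest  # assumed class name in base_test.py

class MyMpFeatureTest(BaseTest):  # hypothetical test case
    # get_core_config() calls this hook when assembling the CoreConfig;
    # returning a config here (instead of the default None) enables
    # mixed precision for the test.
    def get_mixed_precision_config(self):
        return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
```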
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py
index f75b1de69..c2b2db035 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/reused_layer_mixed_precision_test.py
@@ -51,7 +51,7 @@ def get_quantization_config(self):
                                       relu_bound_to_power_of_2=True, weights_bias_correction=True,
                                       input_scaling=True, activation_channel_equalization=True)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return MixedPrecisionQuantizationConfig()
 
     def create_networks(self):
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
index ba1dcce30..9342ec087 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py
@@ -18,7 +18,7 @@
 import tensorflow as tf
 
 from model_compression_toolkit.defaultdict import DefaultDict
-from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_last_layer_weights
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs, generate_keras_tpc
 from tests.common_tests.helpers.generate_test_tp_model import generate_test_op_qc, generate_test_attr_configs
@@ -37,14 +37,14 @@ class MixedPercisionBaseTest(BaseKerasFeatureNetworkTest):
     def __init__(self, unit_test, val_batch_size=1):
-        super().__init__(unit_test, val_batch_size=val_batch_size )
+        super().__init__(unit_test, val_batch_size=val_batch_size)
 
     def get_quantization_config(self):
         return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
                                            relu_bound_to_power_of_2=True, weights_bias_correction=True,
                                            input_scaling=True, activation_channel_equalization=True)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
 
     def get_input_shapes(self):
@@ -80,7 +80,7 @@ def get_quantization_config(self):
                                            relu_bound_to_power_of_2=True, weights_bias_correction=True,
                                            input_scaling=True, activation_channel_equalization=True)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig()
 
     def get_kpi(self):
@@ -96,14 +96,20 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
         self.unit_test.assertTrue(np.unique(conv_layers[1].weights[0]).flatten().shape[0] <= 8)
 
 
-class MixedPercisionSearchTest(MixedPercisionBaseTest):
-    def __init__(self, unit_test):
+class MixedPrecisionSearchTest(MixedPercisionBaseTest):
+    def __init__(self, unit_test, distance_metric=MpDistanceWeighting.AVG):
         super().__init__(unit_test, val_batch_size=2)
 
+        self.distance_metric = distance_metric
+
     def get_kpi(self):
         # kpi is infinity -> should give best model - 8bits
         return KPI(np.inf)
 
+    def get_mixed_precision_config(self):
+        return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
+                                                         distance_weighting_method=self.distance_metric)
+
     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
         conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D)
         assert (quantization_info.mixed_precision_cfg == [0,
@@ -224,7 +230,7 @@ class MixedPercisionCombinedNMSTest(MixedPercisionBaseTest):
     def __init__(self, unit_test):
         super().__init__(unit_test)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
                                                          use_hessian_based_scores=False)
 
@@ -365,7 +371,7 @@ def get_quantization_config(self):
                                            relu_bound_to_power_of_2=False, weights_bias_correction=False,
                                            input_scaling=False, activation_channel_equalization=False)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig()
 
 
@@ -381,7 +387,7 @@ def get_quantization_config(self):
                                            input_scaling=False, activation_channel_equalization=False)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
 
     def get_tpc(self):
@@ -413,9 +419,9 @@ class MixedPercisionSearchLastLayerDistanceTest(MixedPercisionBaseTest):
     def __init__(self, unit_test):
         super().__init__(unit_test, val_batch_size=2)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
-                                                         distance_weighting_method=get_last_layer_weights,
+                                                         distance_weighting_method=MpDistanceWeighting.LAST_LAYER,
                                                          use_hessian_based_scores=False)
 
     def get_kpi(self):
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index c0544753e..6e43f5cd2 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -18,9 +18,10 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.layers import PReLU, ELU
 
 from model_compression_toolkit.core import QuantizationErrorMethod
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.gptq import RoundingType
 from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \
@@ -122,7 +123,7 @@
 from tests.keras_tests.feature_networks_tests.feature_networks.uniform_range_selection_activation_test import \
     UniformRangeSelectionActivationTest, UniformRangeSelectionBoundedActivationTest
 from tests.keras_tests.feature_networks_tests.feature_networks.weights_mixed_precision_tests import \
-    MixedPercisionSearchTest, MixedPercisionDepthwiseTest, \
+    MixedPrecisionSearchTest, MixedPercisionDepthwiseTest, \
     MixedPercisionSearchKPI4BitsAvgTest, MixedPercisionSearchKPI2BitsAvgTest, MixedPrecisionActivationDisabled, \
     MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
     MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
@@ -202,7 +203,8 @@ def test_mixed_precision_search_kpi_4bits_avg_nms(self):
         MixedPercisionCombinedNMSTest(self).run_test()
 
     def test_mixed_precision_search(self):
-        MixedPercisionSearchTest(self).run_test()
+        MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG).run_test()
+        MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test()
 
     def test_mixed_precision_for_part_weights_layers(self):
         MixedPercisionSearchPartWeightsLayersTest(self).run_test()
diff --git a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py
index 732a1aa68..7e865c051 100644
--- a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py
+++ b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py
@@ -19,8 +19,6 @@
 import model_compression_toolkit as mct
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    DEFAULT_MIXEDPRECISION_CONFIG
 from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
 from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
     set_quantization_configuration_to_graph
@@ -30,7 +28,6 @@
 from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
 from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc
 from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs
 
 tp = mct.target_platform
diff --git a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
index 7ab6e4f97..84252d4a0 100644
--- a/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
+++ b/tests/keras_tests/function_tests/test_sensitivity_metric_interest_points.py
@@ -28,7 +28,7 @@
 from keras.layers.core import TFOpLambda
 
 from model_compression_toolkit.constants import AXIS
-from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import get_mp_interest_points
@@ -44,7 +44,7 @@ def build_ip_list_for_test(in_model, num_interest_points_factor):
     mp_qc = MixedPrecisionQuantizationConfig(compute_mse,
-                                             get_average_weights,
+                                             MpDistanceWeighting.AVG,
                                              num_of_images=1,
                                              num_interest_points_factor=num_interest_points_factor)
     fw_info = DEFAULT_KERAS_INFO
diff --git a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
index 59ea34f98..bf0e2c18d 100644
--- a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
+++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py
@@ -17,8 +17,7 @@
 import keras
 
 from model_compression_toolkit.core import DEFAULTCONFIG
-from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights, \
-    get_last_layer_weights
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
@@ -39,7 +38,7 @@
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
     get_op_quantization_configs
 from tests.keras_tests.tpc_keras import get_weights_only_mp_tpc_keras
-
+from pulp import lpSum
 
 class MockReconstructionHelper:
     def __init__(self):
@@ -56,15 +55,16 @@ class MockMixedPrecisionSearchManager:
     def __init__(self, layer_to_kpi_mapping):
         self.layer_to_bitwidth_mapping = {0: [0, 1, 2]}
         self.layer_to_kpi_mapping = layer_to_kpi_mapping
-        self.compute_metric_fn = lambda x, y=None, z=None: 0
+        self.compute_metric_fn = lambda x, y=None, z=None: {0: 2, 1: 1, 2: 0}[x[0]]
         self.min_kpi = {KPITarget.WEIGHTS: [[1], [1], [1]],
                         KPITarget.ACTIVATION: [[1], [1], [1]],
                         KPITarget.TOTAL: [[2], [2], [2]],
                         KPITarget.BOPS: [[1], [1], [1]]}  # minimal kpi in the tests layer_to_kpi_mapping
-        self.compute_kpi_functions = {KPITarget.WEIGHTS: (None, lambda v: [sum(v)]),
+
+        self.compute_kpi_functions = {KPITarget.WEIGHTS: (None, lambda v: [lpSum(v)]),
                                       KPITarget.ACTIVATION: (None, lambda v: [i for i in v]),
-                                      KPITarget.TOTAL: (None, lambda v: [sum(v[0]) + i for i in v[1]]),
-                                      KPITarget.BOPS: (None, lambda v: [sum(v)])}
+                                      KPITarget.TOTAL: (None, lambda v: [lpSum(v[0]) + i for i in v[1]]),
+                                      KPITarget.BOPS: (None, lambda v: [lpSum(v)])}
         self.max_kpi_config = [0]
         self.config_reconstruction_helper = MockReconstructionHelper()
         self.non_conf_kpi_dict = None
@@ -76,8 +76,8 @@ def compute_kpi_matrix(self, target):
         elif target == KPITarget.ACTIVATION:
             kpi_matrix = [np.flip(np.array([kpi.activation_memory - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()]))]
         elif target == KPITarget.TOTAL:
-            kpi_matrix = [np.flip(np.array([kpi.total_memory - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()])),
-                          np.flip(np.array([kpi.total_memory - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()]))]
+            kpi_matrix = [np.flip(np.array([kpi.weights_memory - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()])),
+                          np.flip(np.array([kpi.activation_memory - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()]))]
         elif target == KPITarget.BOPS:
             kpi_matrix = [np.flip(np.array([kpi.bops - 1 for _, kpi in self.layer_to_kpi_mapping[0].items()]))]
         else:
@@ -114,7 +114,7 @@ def test_search_weights_only(self):
                                                 target_kpi=KPI(weights_memory=np.inf))
 
         self.assertTrue(len(bit_cfg) == 1)
-        self.assertTrue(bit_cfg[0] == 2)
+        self.assertTrue(bit_cfg[0] == 0)  # KPI is Inf, so expecting the maximal bit-width result
 
         target_kpi = None  # target KPI is not defined!
         with self.assertRaises(Exception):
@@ -143,7 +143,7 @@ def test_search_activation_only(self):
                                                 target_kpi=KPI(activation_memory=np.inf))
 
         self.assertTrue(len(bit_cfg) == 1)
-        self.assertTrue(bit_cfg[0] == 2)
+        self.assertTrue(bit_cfg[0] == 0)  # KPI is Inf, so expecting the maximal bit-width result
 
     def test_search_weights_and_activation(self):
         target_kpi = KPI(weights_memory=2, activation_memory=2)
@@ -167,13 +167,13 @@ def test_search_weights_and_activation(self):
                                                 target_kpi=KPI(weights_memory=np.inf, activation_memory=np.inf))
 
         self.assertTrue(len(bit_cfg) == 1)
-        self.assertTrue(bit_cfg[0] == 2)
+        self.assertTrue(bit_cfg[0] == 0)  # KPI is Inf, so expecting the maximal bit-width result
 
     def test_search_total_kpi(self):
-        target_kpi = KPI(total_memory=2)
-        layer_to_kpi_mapping = {0: {2: KPI(total_memory=1),
-                                    1: KPI(total_memory=2),
-                                    0: KPI(total_memory=3)}}
+        target_kpi = KPI(total_memory=4)
+        layer_to_kpi_mapping = {0: {2: KPI(weights_memory=1, activation_memory=1),
+                                    1: KPI(weights_memory=2, activation_memory=2),
+                                    0: KPI(weights_memory=3, activation_memory=3)}}
         mock_search_manager = MockMixedPrecisionSearchManager(layer_to_kpi_mapping)
 
         bit_cfg = mp_integer_programming_search(mock_search_manager,
@@ -280,7 +280,7 @@ def representative_data_gen():
     def test_mixed_precision_search_facade(self):
         core_config_avg_weights = CoreConfig(quantization_config=DEFAULTCONFIG,
                                              mixed_precision_config=MixedPrecisionQuantizationConfig(compute_mse,
-                                                                                                     get_average_weights,
+                                                                                                     MpDistanceWeighting.AVG,
                                                                                                      num_of_images=1,
                                                                                                      use_hessian_based_scores=False))
 
@@ -288,7 +288,7 @@ def test_mixed_precision_search_facade(self):
         core_config_last_layer = CoreConfig(quantization_config=DEFAULTCONFIG,
                                             mixed_precision_config=MixedPrecisionQuantizationConfig(compute_mse,
-                                                                                                    get_last_layer_weights,
+                                                                                                    MpDistanceWeighting.LAST_LAYER,
                                                                                                     num_of_images=1,
                                                                                                     use_hessian_based_scores=False))
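Two of the mock changes above interact with the new clamping in `_build_layer_to_metrics_mapping`: `compute_metric_fn` now returns a decreasing non-zero metric per candidate (so with an infinite KPI the clamp leaves index 0, the maximal bit-width, as the unique optimum, which is why the assertions flip to `bit_cfg[0] == 0`), and the KPI aggregators use pulp's `lpSum` so the mocked values stay valid LP expressions over the solver's decision variables. A toy sketch of the `lpSum` usage (problem and variable names are illustrative, not from the repository):

```python
from pulp import LpMinimize, LpProblem, LpVariable, lpSum

# One binary indicator per bit-width candidate, as in the LP search.
indicators = [LpVariable(f'y_{i}', cat='Binary') for i in range(3)]

prob = LpProblem('toy_mp_search', LpMinimize)
# lpSum returns an LpAffineExpression, so it can serve as an objective...
prob += lpSum([2 * indicators[0], 1 * indicators[1], 0 * indicators[2]])
# ...or be compared against a bound inside a constraint, e.g. a KPI target.
prob += lpSum(indicators) == 1  # exactly one candidate is chosen
prob.solve()
```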
diff --git a/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py
index 8b54d870f..9a4f132d3 100644
--- a/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py
+++ b/tests/keras_tests/non_parallel_tests/test_tensorboard_writer.py
@@ -25,8 +25,6 @@
 import model_compression_toolkit as mct
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-    DEFAULT_MIXEDPRECISION_CONFIG
 from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
     ActivationFinalBitwidthConfigVisualizer
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
@@ -37,7 +35,7 @@
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs
 from tests.common_tests.helpers.generate_test_tp_model import generate_tp_model_with_activation_mp
 from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_set_bit_widths
-from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
 
 keras = tf.keras
@@ -111,7 +109,7 @@ def plot_tensor_sizes(self):
         # Hessian service assumes core should be initialized. This test does not do it, so we disable the use of hessians in MP
         cfg = mct.core.DEFAULTCONFIG
         mp_cfg = mct.core.MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
-                                                           distance_weighting_method=get_average_weights,
+                                                           distance_weighting_method=MpDistanceWeighting.AVG,
                                                            use_hessian_based_scores=False)
 
         # compare max tensor size with plotted max tensor size
diff --git a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
index b3fd32fb7..2854cb94c 100644
--- a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
+++ b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
@@ -23,13 +23,13 @@
 from torchvision.models import mobilenet_v2
 
 import model_compression_toolkit as mct
+from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.defaultdict import DefaultDict
 from model_compression_toolkit.constants import PYTORCH
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import Greater, Smaller, Eq
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import DEFAULT_MIXEDPRECISION_CONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
     TFLITE_TP_MODEL, QNNPACK_TP_MODEL, KERNEL_ATTR, WEIGHTS_N_BITS, PYTORCH_KERNEL, BIAS_ATTR, BIAS
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
@@ -240,7 +240,7 @@ def rep_data():
                                               rep_data,
                                               target_platform_capabilities=tpc)
 
-        mp_qc = copy.deepcopy(DEFAULT_MIXEDPRECISION_CONFIG)
+        mp_qc = MixedPrecisionQuantizationConfig()
         mp_qc.num_of_images = 1
         core_config = mct.core.CoreConfig(quantization_config=mct.core.QuantizationConfig(),
                                           mixed_precision_config=mp_qc)
diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
index 7c6c98208..41e76a7d2 100644
--- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_activation_test.py
@@ -126,7 +126,7 @@ def __init__(self, unit_test):
     def get_kpi(self):
         return KPI(np.inf, np.inf)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return MixedPrecisionQuantizationConfig(num_of_images=4)
 
     def create_feature_network(self, input_shape):
diff --git a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py
index 0de9caade..0e6ebf2d6 100644
--- a/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/mixed_precision_weights_test.py
@@ -18,7 +18,7 @@
 from model_compression_toolkit.defaultdict import DefaultDict
 from model_compression_toolkit.core import KPI
-from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_last_layer_weights
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.constants import BIAS
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS_ATTR
@@ -79,12 +79,23 @@ def compare_results(self, quantization_info, quantized_models, float_model, expected_bitwidth_idx):
 
 
 class MixedPercisionSearch8Bit(MixedPercisionBaseTest):
-    def __init__(self, unit_test):
+    def __init__(self, unit_test, distance_metric=MpDistanceWeighting.AVG):
         super().__init__(unit_test)
 
+        self.distance_metric = distance_metric
+
     def get_kpi(self):
         return KPI(np.inf)
 
+    def get_core_configs(self):
+        qc = mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
+                                         relu_bound_to_power_of_2=False, weights_bias_correction=True,
+                                         input_scaling=False, activation_channel_equalization=False)
+        mpc = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
+                                                        distance_weighting_method=self.distance_metric)
+
+        return {"mixed_precision_model": mct.core.CoreConfig(quantization_config=qc, mixed_precision_config=mpc)}
+
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         self.compare_results(quantization_info, quantized_models, float_model, 0)
 
@@ -217,10 +228,10 @@ def __init__(self, unit_test):
     def get_kpi(self):
         return KPI(192)
 
-    def get_mixed_precision_v2_config(self):
+    def get_mixed_precision_config(self):
         return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
                                                          use_hessian_based_scores=False,
-                                                         distance_weighting_method=get_last_layer_weights)
+                                                         distance_weighting_method=MpDistanceWeighting.LAST_LAYER)
 
     def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
         self.compare_results(quantization_info, quantized_models, float_model, 1)
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index 2b3ce5795..1e6a6b71b 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -19,6 +19,7 @@
 import torch
 from torch import nn
 import model_compression_toolkit as mct
+from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
 from model_compression_toolkit.gptq.common.gptq_config import RoundingType
 from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest
 from tests.pytorch_tests.model_tests.feature_models.bn_attributes_quantization_test import BNAttributesQuantization
@@ -394,7 +395,8 @@ def test_mixed_precision_8bit(self):
         """
         This test checks the Mixed Precision search.
         """
-        MixedPercisionSearch8Bit(self).run_test()
+        MixedPercisionSearch8Bit(self, distance_metric=MpDistanceWeighting.AVG).run_test()
+        MixedPercisionSearch8Bit(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test()
 
     def test_mixed_precision_part_weights_layers(self):
         """