Improving Mixed precision experience #985

Merged · 13 commits · Mar 11, 2024
@@ -40,9 +40,9 @@ def set_bit_widths(mixed_precision_enable: bool,
for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
"All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"

Logger.info(f'Set bit widths from configuration: {bit_widths_config}')
# Get a list of the nodes' names we need to finalize (those that have at least one weight qc candidate).
sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)

for node in graph.nodes: # set a specific node qc for each node final qc
# If it's reused, take the configuration that the base node has
node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from enum import Enum
from functools import partial

import numpy as np

@@ -21,6 +23,7 @@ def get_average_weights(distance_matrix: np.ndarray) -> np.ndarray:
Get weights for weighting the sensitivity among different layers when evaluating MP configurations on
the model's sensitivity. This function returns equal weights for each layer, such that the sensitivity
is averaged over all layers.

Args:
distance_matrix: Numpy array of shape (L, M), where L is the number of interest points and M is the
number of samples. The matrix contains the distance for each interest point at each sample.
@@ -50,3 +53,21 @@ def get_last_layer_weights(distance_matrix: np.ndarray) -> np.ndarray:
w = np.asarray([0 for _ in range(num_nodes)])
w[-1] = 1
return w


class MpDistanceWeighting(Enum):
"""
Defines the mixed precision distance metric weighting methods.
Each enum value wraps a weighting function, so the value itself can be called with the function's
arguments and keyword arguments.

AVG - take the average distance over all computed layers.

LAST_LAYER - take only the distance of the last layer's output.

"""

AVG = partial(get_average_weights)
LAST_LAYER = partial(get_last_layer_weights)

def __call__(self, distance_matrix: np.ndarray) -> np.ndarray:
return self.value(distance_matrix)
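Because each member stores its weighting function in a functools.partial, the enum value is directly callable. A minimal usage sketch, assuming the import path touched elsewhere in this diff and a made-up 3×4 distance matrix:

import numpy as np
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting

# Hypothetical distance matrix: L=3 interest points, M=4 samples.
distance_matrix = np.random.rand(3, 4)

# Equal weight per interest point -> sensitivity is averaged over all layers.
avg_weights = MpDistanceWeighting.AVG(distance_matrix)

# All weight on the last interest point -> only the last layer's output counts.
last_weights = MpDistanceWeighting.LAST_LAYER(distance_matrix)

Wrapping the functions in partial is what keeps them from being treated as method definitions, so AVG and LAST_LAYER become proper enum members.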
@@ -15,28 +15,27 @@

from typing import List, Callable

from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting


class MixedPrecisionQuantizationConfig:

def __init__(self,
compute_distance_fn: Callable = None,
distance_weighting_method: Callable = get_average_weights,
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG,
num_of_images: int = 32,
configuration_overwrite: List[int] = None,
num_interest_points_factor: float = 1.0,
use_hessian_based_scores: bool = True,
use_hessian_based_scores: bool = False,
norm_scores: bool = True,
refine_mp_solution: bool = True,
metric_normalization_threshold: float = 1e10):
"""
Class with mixed precision parameters to quantize the input model.

Args:
compute_distance_fn (Callable): Function to compute a distance between two tensors.
distance_weighting_method (Callable): Function to use when weighting the distances among different layers when computing the sensitivity metric.
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, a pre-defined distance method is used for each layer, based on the layer's type.
distance_weighting_method (MpDistanceWeighting): Enum value that provides the function used to weight the distances among the different layers when computing the sensitivity metric.
num_of_images (int): Number of images used to evaluate the sensitivity of a mixed-precision model compared to the float model.
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -63,8 +62,3 @@ def __init__(self,
self.norm_scores = norm_scores

self.metric_normalization_threshold = metric_normalization_threshold


# Default quantization configuration the library use.
DEFAULT_MIXEDPRECISION_CONFIG = MixedPrecisionQuantizationConfig(compute_distance_fn=compute_mse,
distance_weighting_method=get_average_weights)
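With DEFAULT_MIXEDPRECISION_CONFIG removed, a configuration is now built explicitly; every argument has a default, so callers only pass overrides. A minimal sketch using the mct.core API referenced in the tests below:

import model_compression_toolkit as mct
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting

# New defaults: per-layer distance functions (compute_distance_fn=None),
# average distance weighting, and Hessian-based scores disabled.
mp_config = mct.core.MixedPrecisionQuantizationConfig()

# Overriding the weighting so only the last layer's output distance is used.
mp_config_last = mct.core.MixedPrecisionQuantizationConfig(
    num_of_images=32,
    distance_weighting_method=MpDistanceWeighting.LAST_LAYER)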
@@ -232,7 +232,8 @@ def _add_set_of_kpi_constraints(search_manager: MixedPrecisionSearchManager,


def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
target_kpi: KPI) -> Dict[int, Dict[int, float]]:
target_kpi: KPI,
eps: float = 1e-8) -> Dict[int, Dict[int, float]]:
"""
This function measures the sensitivity of a change in a layer's bit-width on the entire model.
It builds a mapping from a node's index to its bit-width's effect on the model's sensitivity.
@@ -245,6 +246,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
search_manager: MixedPrecisionSearchManager object to be used for problem formalization.
target_kpi: KPI to constrain our LP problem with some resource limitations (like the model's weights memory
consumption).
eps: Epsilon value to manually increase the metric value (if necessary) for numerical stability.

Returns:
Mapping from each node's index in a graph, to a dictionary from the bitwidth index (of this node) to
@@ -287,16 +289,18 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
original_base_config=origin_max_config)
origin_changed_nodes_indices = [i for i, c in enumerate(origin_max_config) if
c != origin_mp_model_configuration[i]]
layer_to_metrics_mapping[node_idx][bitwidth_idx] = search_manager.compute_metric_fn(
metric_value = search_manager.compute_metric_fn(
origin_mp_model_configuration,
origin_changed_nodes_indices,
origin_max_config)
else:
layer_to_metrics_mapping[node_idx][bitwidth_idx] = search_manager.compute_metric_fn(
metric_value = search_manager.compute_metric_fn(
mp_model_configuration,
[node_idx],
search_manager.max_kpi_config)

layer_to_metrics_mapping[node_idx][bitwidth_idx] = max(metric_value, max_config_value + eps)

# Finalize distance metric mapping
search_manager.finalize_distance_metric(layer_to_metrics_mapping)
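The new eps clamp guarantees that each candidate's metric ends up strictly above the metric of the maximal bit-width configuration, so later normalization and the LP objective never see a zero or negative sensitivity gap. A self-contained sketch of the idea (the function name is illustrative, not the module's API):

def clamp_metric(metric_value: float, max_config_value: float, eps: float = 1e-8) -> float:
    # Lift any candidate scoring at or below the baseline to just above it.
    return max(metric_value, max_config_value + eps)

# A candidate that (e.g. due to evaluation noise) scored below the baseline is
# clamped upward; one that scored above it is returned unchanged.
assert clamp_metric(0.10, 0.25) == 0.25 + 1e-8
assert clamp_metric(0.30, 0.25) == 0.30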

@@ -101,7 +101,10 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
changed = True

Logger.info(f'Greedy MP algorithm changed configuration from: {mp_solution} to {new_solution}')
if any([mp_solution[i] != new_solution[i] for i in range(len(mp_solution))]):
Logger.info(f'Greedy MP algorithm changed configuration from (numbers represent indices of the '
f'chosen bit-width candidate for each layer):\n{mp_solution}\nto\n{new_solution}')

return new_solution
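For readers of the new log message: each entry in the logged lists is the index of the bit-width candidate chosen for the corresponding configurable layer. A hypothetical before/after pair (candidate ordering assumed, e.g. index 0 = widest bit-width):

# Hypothetical solutions over three configurable layers.
mp_solution = [0, 2, 1]   # candidate indices before refinement
new_solution = [0, 1, 1]  # after a greedy upgrade of layer 1

# The log above fires only when at least one layer actually changed:
assert any(mp_solution[i] != new_solution[i] for i in range(len(mp_solution)))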


11 changes: 7 additions & 4 deletions model_compression_toolkit/core/runner.py

@@ -144,10 +144,13 @@ def core_runner(in_model: Any,
weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
activation_conf_nodes_bitwidth = tg.get_final_activation_config()

Logger.info(
f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')
Logger.info(
f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')
if len(weights_conf_nodes_bitwidth) > 0:
Logger.info(
f'Final weights bit-width configuration: {[node_b[1] for node_b in weights_conf_nodes_bitwidth]}')

if len(activation_conf_nodes_bitwidth) > 0:
Logger.info(
f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')

if tb_w is not None:
finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)
4 changes: 2 additions & 2 deletions tests/common_tests/base_test.py

@@ -35,13 +35,13 @@ def get_input_shapes(self):

def get_core_config(self):
return CoreConfig(quantization_config=self.get_quantization_config(),
mixed_precision_config=self.get_mixed_precision_v2_config(),
mixed_precision_config=self.get_mixed_precision_config(),
debug_config=self.get_debug_config())

def get_quantization_config(self):
return QuantizationConfig()

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return None

def get_debug_config(self):
@@ -46,7 +46,7 @@ def get_tpc(self):
mp_bitwidth_candidates_list=self.mixed_precision_candidates_list,
name="mp_bopts_test")

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return MixedPrecisionQuantizationConfig(num_of_images=1)

def get_input_shapes(self):
@@ -64,7 +64,7 @@ def get_quantization_config(self):
input_scaling=False,
activation_channel_equalization=False)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)

def get_input_shapes(self):
@@ -422,7 +422,7 @@ def get_quantization_config(self):
relu_bound_to_power_of_2=False, weights_bias_correction=True,
input_scaling=False, activation_channel_equalization=False)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=self.num_of_inputs)

def create_networks(self):
@@ -51,7 +51,7 @@ def get_quantization_config(self):
relu_bound_to_power_of_2=True, weights_bias_correction=True,
input_scaling=True, activation_channel_equalization=True)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return MixedPrecisionQuantizationConfig()

def create_networks(self):
@@ -18,7 +18,7 @@
import tensorflow as tf

from model_compression_toolkit.defaultdict import DefaultDict
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_last_layer_weights
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs, generate_keras_tpc
from tests.common_tests.helpers.generate_test_tp_model import generate_test_op_qc, generate_test_attr_configs
@@ -37,14 +37,14 @@

class MixedPercisionBaseTest(BaseKerasFeatureNetworkTest):
def __init__(self, unit_test, val_batch_size=1):
super().__init__(unit_test, val_batch_size=val_batch_size )
super().__init__(unit_test, val_batch_size=val_batch_size)

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE,
relu_bound_to_power_of_2=True, weights_bias_correction=True,
input_scaling=True, activation_channel_equalization=True)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)

def get_input_shapes(self):
@@ -80,7 +80,7 @@ def get_quantization_config(self):
relu_bound_to_power_of_2=True, weights_bias_correction=True,
input_scaling=True, activation_channel_equalization=True)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig()

def get_kpi(self):
@@ -96,14 +96,20 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
self.unit_test.assertTrue(np.unique(conv_layers[1].weights[0]).flatten().shape[0] <= 8)


class MixedPercisionSearchTest(MixedPercisionBaseTest):
def __init__(self, unit_test):
class MixedPrecisionSearchTest(MixedPercisionBaseTest):
def __init__(self, unit_test, distance_metric=MpDistanceWeighting.AVG):
super().__init__(unit_test, val_batch_size=2)

self.distance_metric = distance_metric

def get_kpi(self):
# kpi is infinity -> should give best model - 8bits
return KPI(np.inf)

def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
distance_weighting_method=self.distance_metric)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D)
assert (quantization_info.mixed_precision_cfg == [0,
@@ -224,7 +230,7 @@ class MixedPercisionCombinedNMSTest(MixedPercisionBaseTest):
def __init__(self, unit_test):
super().__init__(unit_test)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
use_hessian_based_scores=False)

@@ -365,7 +371,7 @@ def get_quantization_config(self):
relu_bound_to_power_of_2=False, weights_bias_correction=False,
input_scaling=False, activation_channel_equalization=False)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig()


@@ -381,7 +387,7 @@ def get_quantization_config(self):
input_scaling=False,
activation_channel_equalization=False)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)

def get_tpc(self):
@@ -413,9 +419,9 @@ class MixedPercisionSearchLastLayerDistanceTest(MixedPercisionBaseTest):
def __init__(self, unit_test):
super().__init__(unit_test, val_batch_size=2)

def get_mixed_precision_v2_config(self):
def get_mixed_precision_config(self):
return mct.core.MixedPrecisionQuantizationConfig(num_of_images=1,
distance_weighting_method=get_last_layer_weights,
distance_weighting_method=MpDistanceWeighting.LAST_LAYER,
use_hessian_based_scores=False)

def get_kpi(self):
@@ -18,9 +18,11 @@

import numpy as np
import tensorflow as tf
from sklearn.metrics.pairwise import distance_metrics
from tensorflow.keras.layers import PReLU, ELU

from model_compression_toolkit.core import QuantizationErrorMethod
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.gptq import RoundingType
from tests.keras_tests.feature_networks_tests.feature_networks.activation_decomposition_test import \
@@ -122,7 +124,7 @@
from tests.keras_tests.feature_networks_tests.feature_networks.uniform_range_selection_activation_test import \
UniformRangeSelectionActivationTest, UniformRangeSelectionBoundedActivationTest
from tests.keras_tests.feature_networks_tests.feature_networks.weights_mixed_precision_tests import \
MixedPercisionSearchTest, MixedPercisionDepthwiseTest, \
MixedPrecisionSearchTest, MixedPercisionDepthwiseTest, \
MixedPercisionSearchKPI4BitsAvgTest, MixedPercisionSearchKPI2BitsAvgTest, MixedPrecisionActivationDisabled, \
MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
@@ -202,7 +204,8 @@ def test_mixed_precision_search_kpi_4bits_avg_nms(self):
MixedPercisionCombinedNMSTest(self).run_test()

def test_mixed_precision_search(self):
MixedPercisionSearchTest(self).run_test()
MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.AVG).run_test()
MixedPrecisionSearchTest(self, distance_metric=MpDistanceWeighting.LAST_LAYER).run_test()

def test_mixed_precision_for_part_weights_layers(self):
MixedPercisionSearchPartWeightsLayersTest(self).run_test()
@@ -19,8 +19,6 @@

import model_compression_toolkit as mct
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
DEFAULT_MIXEDPRECISION_CONFIG
from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \
set_quantization_configuration_to_graph
@@ -30,7 +28,6 @@
from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc
from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs

tp = mct.target_platform

@@ -28,7 +28,7 @@
from keras.layers.core import TFOpLambda

from model_compression_toolkit.constants import AXIS
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import get_average_weights
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import get_mp_interest_points
Expand All @@ -44,7 +44,7 @@

def build_ip_list_for_test(in_model, num_interest_points_factor):
mp_qc = MixedPrecisionQuantizationConfig(compute_mse,
get_average_weights,
MpDistanceWeighting.AVG,
num_of_images=1,
num_interest_points_factor=num_interest_points_factor)
fw_info = DEFAULT_KERAS_INFO