diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index c90092009..a6f0fca63 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -69,7 +69,8 @@ # that are shared among different candidates: WEIGHTS_NBITS_ATTRIBUTE = 'weights_n_bits' CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias' -ACTIVATION_NBITS_ATTRIBUTE = 'activation_n_bits' +ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits' +SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE = 'supported_input_activation_n_bits' # Quantization Parameters Iterative Search Defaults: SYMMETRIC_TENSOR_N_ITER = 40 diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index 59d6c8d7b..9833e20cd 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -19,11 +19,11 @@ import numpy as np from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \ - ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER + ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \ - TargetPlatformCapabilities, LayerFilterParams + TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig class BaseNode: @@ -297,7 +297,6 @@ def get_memory_bytes(self, fw_info) -> float: return memory - def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]: """ In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration. @@ -343,7 +342,7 @@ def get_unified_activation_candidates_dict(self) -> Dict[str, Any]: Returns: A dictionary containing information from node's activation quantization configuration candidates. """ - shared_attributes = [ACTIVATION_NBITS_ATTRIBUTE] + shared_attributes = [ACTIVATION_N_BITS_ATTRIBUTE] attr = dict() if self.is_activation_quantization_enabled(): attr = copy.deepcopy(self.candidates_quantization_cfg[0].activation_quantization_cfg.__dict__) @@ -539,7 +538,7 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions: to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel. Args: - tpc: TPC to extract the QuantizationConfigOptions for the node + tpc: TPC to extract the QuantizationConfigOptions for the node. Returns: QuantizationConfigOptions of the node. @@ -559,6 +558,52 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions: return matching_qcos[0] return tpc.tp_model.default_qco + def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities, + next_nodes: List, node_qc_options: QuantizationConfigOptions + ) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: + """ + Filter quantization config options that don't match the graph. + A node may have several quantization config options with 'activation_n_bits' values, and + the next nodes in the graph may support different bit-width as input activation. This function + filters out quantization config that don't comply to these attributes. + + Args: + tpc: TPC to extract the QuantizationConfigOptions for the next nodes. + next_nodes: Output nodes of current node. 
+ node_qc_options: Node's QuantizationConfigOptions. + + Returns: + + """ + # Filter quantization config options that don't match the graph. + _base_config = node_qc_options.base_config + _node_qc_options = node_qc_options.quantization_config_list + if len(next_nodes): + next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes] + next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits + for qc_opts in next_nodes_qc_options + for op_cfg in qc_opts.quantization_config_list]) + + # Filter node's QC options that match next nodes input bit-width. + _node_qc_options = [_option for _option in _node_qc_options + if _option.activation_n_bits <= next_nodes_supported_input_bitwidth] + if len(_node_qc_options) == 0: + Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover + + # Verify base config match + if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits + for qc_opt in next_nodes_qc_options]): + # base_config activation bits doesn't match next node supported input bit-width -> replace with + # a qco from quantization_config_list with maximum activation bit-width. + if len(_node_qc_options) > 0: + output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)} + _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]] + Logger.warning(f"Node {self} base quantization config changed to match Graph and TPC configuration.\nCause: {self} -> {next_nodes}.") + else: + Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover + + return _base_config, _node_qc_options + def is_match_type(self, _type: Type) -> bool: """ Check if input type matches the node type, either in instance type or in type name. diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py index 7226e7e31..ef6b26072 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py @@ -195,12 +195,12 @@ def compute_total_bops(graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkI def requires_mixed_precision(in_model: Any, - target_resource_utilization: ResourceUtilization, - representative_data_gen: Callable, - core_config: CoreConfig, - tpc: TargetPlatformCapabilities, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> bool: + target_resource_utilization: ResourceUtilization, + representative_data_gen: Callable, + core_config: CoreConfig, + tpc: TargetPlatformCapabilities, + fw_info: FrameworkInfo, + fw_impl: FrameworkImplementation) -> bool: """ The function checks whether the model requires mixed precision to meet the requested target resource utilization. 
This is determined by whether the target memory usage of the weights is less than the available memory, diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 6298fa252..44df33879 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -96,6 +96,7 @@ def __init__(self, self.activation_n_bits = op_cfg.activation_n_bits self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2 self.enable_activation_quantization = op_cfg.enable_activation_quantization + self.is_signed = op_cfg.is_signed self.activation_channel_equalization = qc.activation_channel_equalization self.input_scaling = qc.input_scaling self.min_threshold = qc.min_threshold diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py index 5070b7234..455dc318b 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py @@ -19,7 +19,7 @@ import model_compression_toolkit.core.common.quantization.quantization_config as qc from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL, \ - LUT_VALUES_BITWIDTH, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES + LUT_VALUES_BITWIDTH, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \ max_power_of_two, int_quantization_with_threshold @@ -110,7 +110,8 @@ def lut_kmeans_histogram(bins: np.ndarray, constrained: bool = True, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, - quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> Dict: + quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE, + is_signed: bool = None) -> Dict: """ Finds quantization cluster points for non-uniform activation quantization. The quantizer first finds the closest power-of-two number to the max value of the given histogram, @@ -129,6 +130,7 @@ def lut_kmeans_histogram(bins: np.ndarray, n_iter: Number of iteration ot search for the threshold (not used for this method). min_threshold: Minimal threshold to use if threshold is too small. quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method). + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: A dictionary containing the cluster assignments according to the k-means algorithm and @@ -148,9 +150,9 @@ def lut_kmeans_histogram(bins: np.ndarray, tensor_max = np.max(bins_with_values) threshold = max_power_of_two(tensor_max, min_threshold) - signed = np.any(bins[:-1][counts != 0] < 0) # Whether histogram contains negative values or not. + signed = np.any(bins[:-1][counts != 0] < 0) if is_signed is None else is_signed # Whether histogram contains negative values or not. 
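# Reviewer note: the "<derived from histogram> if is_signed is None else is_signed" pattern added on
# the line above recurs in every histogram-based selection function touched by this patch. A minimal,
# self-contained sketch of the intended precedence (illustrative only; resolve_signedness is not a
# function in this patch):
import numpy as np

def resolve_signedness(bins: np.ndarray, counts: np.ndarray, is_signed: bool = None) -> bool:
    # An explicit value (e.g. OpQuantizationConfig.is_signed coming from the TPC) takes precedence.
    if is_signed is not None:
        return is_signed
    # Otherwise infer signedness from the data: any populated bin with a negative left edge.
    return bool(np.any(bins[:-1][counts != 0] < 0))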
tensor_for_kmeans = int_quantization_with_threshold(data=bins, threshold=threshold, n_bits=LUT_VALUES_BITWIDTH, signed=signed) kmeans.fit(tensor_for_kmeans.reshape(-1, 1), sample_weight=np.insert(counts, 0, 0)) return {LUT_VALUES: np.float32(np.round(kmeans.cluster_centers_)), - THRESHOLD: threshold} + THRESHOLD: threshold, SIGNED: signed} diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py index 65c1055f4..2d2241424 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py @@ -16,7 +16,7 @@ from typing import Union, Tuple, Dict import model_compression_toolkit.core.common.quantization.quantization_config as qc -from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES +from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \ qparams_selection_tensor_search, qparams_selection_histogram_search @@ -105,7 +105,8 @@ def power_of_two_selection_histogram(bins: np.ndarray, constrained: bool = True, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, - quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict: + quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE, + is_signed: bool = None) -> Dict: """ Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize a histogram. Different search is applied, depends on the value of the selected QuantizationErrorMethod. @@ -121,6 +122,7 @@ def power_of_two_selection_histogram(bins: np.ndarray, n_iter: Number of iteration ot search for the threshold (not used for this method). min_threshold: Minimal threshold to use if threshold is too small (used only for kl threshold selection). quant_error_method: an error function to optimize the parameters' selection accordingly. + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: Power of two threshold to quantize the histogram a power of 2 manner. @@ -128,17 +130,20 @@ def power_of_two_selection_histogram(bins: np.ndarray, if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING: tensor_max = np.max(np.abs(bins)[1:][counts > 0]) threshold = max_power_of_two(tensor_max, min_threshold) + # Resolve is_signed in case it is None. 
+ signed = (bins<0).any() if is_signed is None else is_signed else: error_function = get_threshold_selection_histogram_error_function(QuantizationMethod.POWER_OF_TWO, quant_error_method, p) - threshold = qparams_selection_histogram_search(error_function, - bins, - counts, - n_bits, - constrained=constrained, - n_iter=n_iter, - min_threshold=min_threshold) - return {THRESHOLD: threshold} + threshold, signed = qparams_selection_histogram_search(error_function, + bins, + counts, + n_bits, + constrained=constrained, + n_iter=n_iter, + min_threshold=min_threshold, + is_signed=is_signed) + return {THRESHOLD: threshold, SIGNED: signed} def power_of_two_no_clipping_selection_min_max(bins: np.ndarray, @@ -151,7 +156,8 @@ def power_of_two_no_clipping_selection_min_max(bins: np.ndarray, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, quant_error_method: qc.QuantizationErrorMethod = - qc.QuantizationErrorMethod.NOCLIPPING) -> dict: + qc.QuantizationErrorMethod.NOCLIPPING, + is_signed: bool = None) -> Dict: """ Gets a threshold between min and max numbers. If computed threshold is less than min_threshold, min_threshold is returned. @@ -168,4 +174,5 @@ def power_of_two_no_clipping_selection_min_max(bins: np.ndarray, constrained, n_iter, min_threshold=min_threshold, - quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING) + quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING, + is_signed=is_signed) diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py index 558027d74..7420ecad8 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== import numpy as np -from typing import Dict +from typing import Dict, Union from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector @@ -25,7 +25,7 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig, nodes_prior_info: NodePriorInfo, - out_stats_container: BaseStatsCollector) -> Dict[str, float]: + out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]: """ Compute the activations params for a given node in a graph according to a params function. 
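# Reviewer note: the hunk below changes how get_activations_qparams resolves activation signedness.
# The precedence, as implemented in the patched code, is:
#   1. activation_quant_cfg.is_signed, when the TPC op config forces a value (not None);
#   2. min_value < 0, when the node's output is known to be bounded;
#   3. otherwise, whether the collected histogram contains negative values.
# The resolved value is then forwarded to the selection function as `is_signed`, and the selection
# functions (see the changes above and below) become responsible for returning it under the SIGNED
# key, instead of the caller patching the result dictionary afterwards.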
@@ -49,7 +49,9 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf bins_counts) min_value, max_value = out_stats_container.get_min_max_values() - if nodes_prior_info.is_output_bounded(): + if activation_quant_cfg.is_signed is not None: + signed = activation_quant_cfg.is_signed + elif nodes_prior_info.is_output_bounded(): signed = min_value < 0 else: signed = np.any(bins_values[:-1][bins_counts > 0] < 0) @@ -65,14 +67,12 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf activation_quant_cfg.activation_quantization_params_fn = \ quantization_params_generation.uniform_no_clipping_selection_min_max - activation_params = activation_quant_cfg.activation_quantization_params_fn(bins_values, - bins_counts, - activation_quant_cfg.l_p_value, - activation_quant_cfg.activation_n_bits, - min_value, - max_value, - min_threshold=activation_quant_cfg.min_threshold, - quant_error_method=activation_quant_cfg.activation_error_method) - activation_params.update({SIGNED: signed}) - - return activation_params + return activation_quant_cfg.activation_quantization_params_fn(bins_values, + bins_counts, + activation_quant_cfg.l_p_value, + activation_quant_cfg.activation_n_bits, + min_value, + max_value, + min_threshold=activation_quant_cfg.min_threshold, + quant_error_method=activation_quant_cfg.activation_error_method, + is_signed=signed) diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py index d5f33326d..aefd84901 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py @@ -107,7 +107,8 @@ def qparams_selection_histogram_search(error_function: Callable, n_bits: int, constrained: bool = True, n_iter: int = 10, - min_threshold: float = MIN_THRESHOLD): + min_threshold: float = MIN_THRESHOLD, + is_signed: bool = None) -> Tuple[np.ndarray, bool]: """ Search for an optimal threshold to quantize a histogram of collected float values. The search_methods starts with the constrained no-clipping threshold by the bins' maximal value, and continues with @@ -123,13 +124,14 @@ def qparams_selection_histogram_search(error_function: Callable, constrained: Whether the threshold should be constrained or not. n_iter: Number of searching iterations. min_threshold: Threshold to return if the computed threshold is smaller that min_threshold. + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: Optimal constrained threshold to quantize the tensor. """ - signed = np.any(bins < 0) # Whether histogram contains negative values or not. + signed = (bins < 0).any() if is_signed is None else is_signed # Whether histogram contains negative values or not. tensor_data = np.abs(bins) tensor_max = np.max(tensor_data) if not constrained: @@ -150,7 +152,7 @@ def qparams_selection_histogram_search(error_function: Callable, error_list.append(error) # Return the threshold with the minimal error. 
- return np.maximum(threshold_list[np.argmin(error_list)], min_threshold) + return np.maximum(threshold_list[np.argmin(error_list)], min_threshold), signed def qparams_symmetric_iterative_minimization(x0: np.ndarray, @@ -537,7 +539,8 @@ def qparams_symmetric_selection_histogram_search(error_function: Callable, counts: np.ndarray, n_bits: int, n_iter: int = SYMMETRIC_HISTOGRAM_N_ITER, - min_threshold: float = MIN_THRESHOLD): + min_threshold: float = MIN_THRESHOLD, + is_signed: bool = None) -> Tuple[np.ndarray, bool]: """ search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a histogram, using the iterative optimizer method. @@ -550,12 +553,13 @@ def qparams_symmetric_selection_histogram_search(error_function: Callable, n_bits: Number of bits to quantize the tensor. n_iter: Number of searching iterations. min_threshold: Threshold to return if the computed threshold is smaller that min_threshold. + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: Optimized threshold for quantifying the histogram. """ - signed = np.any(bins[:-1][counts != 0] < 0) # Whether histogram contains negative values or not. + signed = np.any(bins[:-1][counts != 0] < 0) if is_signed is None else is_signed # Whether histogram contains negative values or not. res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max), x=bins, @@ -570,7 +574,7 @@ def qparams_symmetric_selection_histogram_search(error_function: Callable, n_iter=SYMMETRIC_HISTOGRAM_N_ITER, dec_freq=SYMMETRIC_HISTOGRAM_DEC_FREQ, per_channel=False) - return max(min_threshold, res['param']) + return max(min_threshold, res['param']), signed def kl_qparams_symmetric_selection_histogram_search(error_function: Callable, @@ -579,7 +583,8 @@ def kl_qparams_symmetric_selection_histogram_search(error_function: Callable, counts: np.ndarray, n_bits: int, n_iter: int = SYMMETRIC_HISTOGRAM_N_ITER, - min_threshold: float = MIN_THRESHOLD): + min_threshold: float = MIN_THRESHOLD, + is_signed: bool = None) -> Tuple[np.ndarray, bool]: """ Search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a histogram, with KL-Divergence loss function (needs a separate search function @@ -599,7 +604,7 @@ def kl_qparams_symmetric_selection_histogram_search(error_function: Callable, Optimized threshold for quantifying the histogram. """ - signed = np.any(bins[:-1][counts != 0] < 0) # Whether histogram contains negative values or not. + signed = np.any(bins[:-1][counts != 0] < 0) if is_signed is None else is_signed # Whether histogram contains negative values or not. 
res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max), x=bins, loss_fn=lambda x, q_x, t: @@ -617,7 +622,7 @@ def kl_qparams_symmetric_selection_histogram_search(error_function: Callable, n_iter=SYMMETRIC_HISTOGRAM_N_ITER, dec_freq=SYMMETRIC_HISTOGRAM_DEC_FREQ, per_channel=False) - return max(min_threshold, res['param']) + return max(min_threshold, res['param']), signed def qparams_uniform_selection_histogram_search(error_function: Callable, diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py index b07b1a470..73cb1077d 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py @@ -16,7 +16,7 @@ from typing import Union, Tuple, Dict import model_compression_toolkit.core.common.quantization.quantization_config as qc -from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES +from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \ get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function, _kl_error_histogram @@ -106,7 +106,8 @@ def symmetric_selection_histogram(bins: np.ndarray, constrained: bool = True, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, - quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict: + quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE, + is_signed: bool = None) -> Dict: """ Compute the optimal threshold based on the provided QuantizationErrorMethod to quantize a histogram. Different search is applied, depends on the value of the selected QuantizationErrorMethod. @@ -122,6 +123,7 @@ def symmetric_selection_histogram(bins: np.ndarray, n_iter: Number of iteration ot search for the threshold (not used for this method). min_threshold: Minimal threshold to use if threshold is too small (used only for kl threshold selection). quant_error_method: an error function to optimize the parameters' selection accordingly. + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: Optimal threshold to quantize the histogram a symmetric manner. @@ -129,23 +131,27 @@ def symmetric_selection_histogram(bins: np.ndarray, tensor_max = np.max(np.abs(bins)[1:][counts > 0]) if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING: threshold = get_init_threshold(min_threshold, tensor_max) + # Resolve is_signed in case it is None. + signed = (bins<0).any() if is_signed is None else is_signed elif quant_error_method == qc.QuantizationErrorMethod.KL: # search for KL error is separated because the error method signature is different from the other error methods. 
- threshold = kl_qparams_symmetric_selection_histogram_search(_kl_error_histogram, - tensor_max, - bins, - counts, - n_bits, - min_threshold=min_threshold) + threshold, signed = kl_qparams_symmetric_selection_histogram_search(_kl_error_histogram, + tensor_max, + bins, + counts, + n_bits, + min_threshold=min_threshold, + is_signed=is_signed) else: error_function = get_threshold_selection_histogram_error_function(QuantizationMethod.SYMMETRIC, quant_error_method, p) - threshold = qparams_symmetric_selection_histogram_search(error_function, - tensor_max, - bins, - counts, - n_bits, - min_threshold=min_threshold) - return {THRESHOLD: threshold} + threshold, signed = qparams_symmetric_selection_histogram_search(error_function, + tensor_max, + bins, + counts, + n_bits, + min_threshold=min_threshold, + is_signed=is_signed) + return {THRESHOLD: threshold, SIGNED: signed} def symmetric_no_clipping_selection_min_max(bins: np.ndarray, @@ -158,7 +164,8 @@ def symmetric_no_clipping_selection_min_max(bins: np.ndarray, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, quant_error_method: qc.QuantizationErrorMethod = - qc.QuantizationErrorMethod.NOCLIPPING) -> dict: + qc.QuantizationErrorMethod.NOCLIPPING, + is_signed: bool = None) -> Dict: """ Gets a threshold between min and max numbers. If computed threshold is less than min_threshold, min_threshold is returned. @@ -175,7 +182,8 @@ def symmetric_no_clipping_selection_min_max(bins: np.ndarray, constrained, n_iter, min_threshold=min_threshold, - quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING) + quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING, + is_signed=is_signed) def get_init_threshold(min_threshold: float, tensor_max: np.ndarray, per_channel: bool = False) -> np.ndarray: diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py index f64a1e046..624c6bd9f 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py @@ -16,7 +16,7 @@ from typing import Union, Tuple, Dict import model_compression_toolkit.core.common.quantization.quantization_config as qc -from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX, NUM_QPARAM_HESSIAN_SAMPLES +from model_compression_toolkit.constants import MIN_THRESHOLD, RANGE_MIN, RANGE_MAX, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED from model_compression_toolkit.core.common.hessian import HessianInfoService from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \ qparams_uniform_selection_tensor_search, qparams_uniform_selection_histogram_search @@ -114,7 +114,8 @@ def uniform_selection_histogram(bins: np.ndarray, constrained: bool = True, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, - quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict: + quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE, + is_signed: bool = None) -> Dict: """ Compute the optimal quantization range based on the provided QuantizationErrorMethod to uniformly quantize the histogram. @@ -131,6 +132,7 @@ def uniform_selection_histogram(bins: np.ndarray, n_iter: Number of iteration ot search for the threshold (not used for this method). 
min_threshold: Minimal threshold to use if threshold is too small (not used for this method). quant_error_method: an error function to optimize the range parameters selection accordingly. + is_signed: Whether the quantization is signed or not. If None then compute SIGNED value. Returns: Optimal quantization range to quantize the histogram uniformly. @@ -139,6 +141,7 @@ def uniform_selection_histogram(bins: np.ndarray, tensor_max = np.max(bins[1:][counts > 0]) tensor_min_max = np.array([tensor_min, tensor_max]) + signed = tensor_min < 0 if is_signed is None else is_signed if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING: mm = tensor_min_max else: @@ -150,7 +153,7 @@ def uniform_selection_histogram(bins: np.ndarray, n_bits) return {RANGE_MIN: mm[0], - RANGE_MAX: mm[1]} + RANGE_MAX: mm[1], SIGNED: signed} def uniform_no_clipping_selection_min_max(bins: np.ndarray, @@ -163,7 +166,8 @@ def uniform_no_clipping_selection_min_max(bins: np.ndarray, n_iter: int = 20, min_threshold: float = MIN_THRESHOLD, quant_error_method: qc.QuantizationErrorMethod = - qc.QuantizationErrorMethod.NOCLIPPING) -> dict: + qc.QuantizationErrorMethod.NOCLIPPING, + is_signed: bool = None) -> Dict: """ Gets a quantization rage between min and max numbers. @@ -179,5 +183,5 @@ def uniform_no_clipping_selection_min_max(bins: np.ndarray, constrained, n_iter, min_threshold=min_threshold, - quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING) - + quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING, + is_signed=is_signed) diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index 09ef451a2..b09bb99a5 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -64,6 +64,7 @@ def set_quantization_configuration_to_graph(graph: Graph, for n in graph.nodes: set_quantization_configs_to_node(node=n, + graph=graph, quant_config=quant_config, fw_info=graph.fw_info, tpc=graph.tpc, @@ -72,6 +73,7 @@ def set_quantization_configuration_to_graph(graph: Graph, def set_quantization_configs_to_node(node: BaseNode, + graph: Graph, quant_config: QuantizationConfig, fw_info: FrameworkInfo, tpc: TargetPlatformCapabilities, @@ -81,19 +83,22 @@ def set_quantization_configs_to_node(node: BaseNode, Args: node: Node to set its quantization configurations. + graph: Model's internal representation graph. quant_config: Quantization configuration to generate the node's configurations from. fw_info: Information needed for quantization about the specific framework. tpc: TargetPlatformCapabilities to get default OpQuantizationConfig. mixed_precision_enable: is mixed precision enabled. 
""" node_qc_options = node.get_qco(tpc) + base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options) # Create QC candidates for weights and activation combined weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type) node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config, fw_info, weight_channel_axis, - node_qc_options, + node_qc_options_list, + base_config, node, mixed_precision_enable=mixed_precision_enable) @@ -186,7 +191,8 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig, def _create_node_candidates_qc(qc: QuantizationConfig, fw_info: FrameworkInfo, weight_channel_axis: Tuple[int, int], - node_qc_options: QuantizationConfigOptions, + node_qc_options_list: List[OpQuantizationConfig], + base_config: OpQuantizationConfig, node: BaseNode, mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]: """ @@ -196,7 +202,8 @@ def _create_node_candidates_qc(qc: QuantizationConfig, qc: Quantization configuration the quantization process should follow. fw_info: Framework information (e.g., which layers should have their kernels' quantized). weight_channel_axis: (Output, Input) channel index of the node's kernel. - node_qc_options: QuantizationConfigOptions for the node with quantization candidates information. + node_qc_options_list: List of quantization configs of node. + base_config: Base quantization config for node. node: A node to set quantization configuration candidates to. mixed_precision_enable: is mixed precision enabled @@ -208,7 +215,7 @@ def _create_node_candidates_qc(qc: QuantizationConfig, node_attrs_list = node.get_node_weights_attributes() if mixed_precision_enable: - for op_cfg in node_qc_options.quantization_config_list: + for op_cfg in node_qc_options_list: candidate_qc = copy.deepcopy(qc) candidates.append(_create_node_single_candidate_qc(candidate_qc, fw_info, @@ -220,7 +227,7 @@ def _create_node_candidates_qc(qc: QuantizationConfig, candidates.append(_create_node_single_candidate_qc(qc, fw_info, weight_channel_axis, - node_qc_options.base_config, + base_config, node_attrs_list)) return candidates diff --git a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py index 47bc827f8..6dcc1a6d2 100644 --- a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +++ b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py @@ -349,9 +349,15 @@ def shift_negative_function(graph: Graph, add_node.output_shape, pad_top, pad_btm, pad_left, pad_right) + # Insert a pad node between the add node to the op2d, and create statistics for the pad node + insert_node_before_node(graph, + node_to_insert=pad_node, + last_node=op2d_node) + # Set quantization configuration to node, even though we do not quantize it: set_quantization_configs_to_node(fw_info=fw_info, node=pad_node, + graph=graph, quant_config=core_config.quantization_config, tpc=graph.tpc, mixed_precision_enable=core_config.mixed_precision_enable) @@ -361,11 +367,6 @@ def shift_negative_function(graph: Graph, for attr in pad_node.get_node_weights_attributes(): candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False - # Insert a pad node between the add node to the op2d, and create statistics for the pad node - insert_node_before_node(graph, - node_to_insert=pad_node, - last_node=op2d_node) - 
graph.set_out_stats_collector_to_node(pad_node, add_node_stats_collector) # We ignore the padding effect on statistics @@ -373,6 +374,7 @@ def shift_negative_function(graph: Graph, set_quantization_configs_to_node(fw_info=fw_info, node=add_node, + graph=graph, quant_config=core_config.quantization_config, tpc=graph.tpc, mixed_precision_enable=core_config.mixed_precision_enable) diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py index 5850a170b..f3bdf5da0 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py @@ -14,7 +14,7 @@ # ============================================================================== import copy -from typing import List, Dict, Union, Any +from typing import List, Dict, Union, Any, Tuple from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH @@ -114,11 +114,13 @@ def __init__(self, attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig], activation_quantization_method: QuantizationMethod, activation_n_bits: int, + supported_input_activation_n_bits: Union[int, Tuple[int]], enable_activation_quantization: bool, quantization_preserving: bool, fixed_scale: float, fixed_zero_point: int, - simd_size: int + simd_size: int, + is_signed: bool = None ): """ @@ -127,10 +129,12 @@ def __init__(self, attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration. activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization. activation_n_bits (int): Number of bits to quantize the activations. + supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input. enable_activation_quantization (bool): Whether to quantize the model activations or not. quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output. fixed_scale (float): Scale to use for an operator quantization parameters. fixed_zero_point (int): Zero-point to use for an operator quantization parameters. + is_signed (bool): Force activation quantization signedness (None means don't force). simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction. 
""" @@ -140,10 +144,17 @@ def __init__(self, self.activation_quantization_method = activation_quantization_method self.activation_n_bits = activation_n_bits + if isinstance(supported_input_activation_n_bits, tuple): + self.supported_input_activation_n_bits = supported_input_activation_n_bits + elif isinstance(supported_input_activation_n_bits, int): + self.supported_input_activation_n_bits = (supported_input_activation_n_bits,) + else: + Logger.critical(f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(supported_input_activation_n_bits)}") # pragma: no cover self.enable_activation_quantization = enable_activation_quantization self.quantization_preserving = quantization_preserving self.fixed_scale = fixed_scale self.fixed_zero_point = fixed_zero_point + self.is_signed = is_signed self.simd_size = simd_size def get_info(self): @@ -193,9 +204,21 @@ def __eq__(self, other): self.attr_weights_configs_mapping == other.attr_weights_configs_mapping and \ self.activation_quantization_method == other.activation_quantization_method and \ self.activation_n_bits == other.activation_n_bits and \ + self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \ self.enable_activation_quantization == other.enable_activation_quantization and \ + self.is_signed == other.is_signed and \ self.simd_size == other.simd_size + @property + def max_input_activation_n_bits(self) -> int: + """ + Get maximum supported input bit-width. + + Returns: Maximum supported input bit-width. + + """ + return max(self.supported_input_activation_n_bits) + class QuantizationConfigOptions: """ diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py index 0f3b809de..00cc06403 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py @@ -29,6 +29,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v2_lut from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import get_keras_tpc as get_keras_tpc_v3 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_keras import get_keras_tpc as get_keras_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import get_keras_tpc as get_keras_tpc_v4 # Keras: TPC versioning keras_tpc_models_dict = {'v1': get_keras_tpc_v1, @@ -38,6 +39,7 @@ 'v2_lut': get_keras_tpc_v2_lut, 'v3': get_keras_tpc_v3, 'v3_lut': get_keras_tpc_v3_lut, + 'v4': get_keras_tpc_v4, LATEST: get_keras_tpc_latest} ############################### @@ -60,6 +62,8 @@ get_pytorch_tpc as get_pytorch_tpc_v3 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3_lut.tpc_pytorch import \ get_pytorch_tpc as get_pytorch_tpc_v3_lut + from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_pytorch import \ + get_pytorch_tpc as get_pytorch_tpc_v4 # Pytorch: TPC versioning pytorch_tpc_models_dict = {'v1': get_pytorch_tpc_v1, @@ -69,6 +73,7 @@ 'v2_lut': get_pytorch_tpc_v2_lut, 'v3': get_pytorch_tpc_v3, 'v3_lut': get_pytorch_tpc_v3_lut, + 'v4': 
get_pytorch_tpc_v4, LATEST: get_pytorch_tpc_latest} tpc_dict = {TENSORFLOW: keras_tpc_models_dict, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py index 0b7adf256..a00dd561d 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py @@ -93,6 +93,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -105,6 +106,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py index a66b124f4..48011ef19 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py @@ -89,6 +89,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -101,6 +102,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py index ccaa5c5c9..56dbcbde0 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py @@ -89,6 +89,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -101,6 +102,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza default_weight_attr_config=default_weight_attr_config, attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_n_bits=8, + 
supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py index 47dc07303..9cea6762b 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py @@ -95,6 +95,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -107,6 +108,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py index 30fbca1cd..f2f790ae8 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py @@ -91,6 +91,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -103,6 +104,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py index abcf0ca6d..292df836a 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py @@ -95,6 +95,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -107,6 +108,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py 
b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py index 5fbc2f02e..c09ec6957 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py @@ -91,6 +91,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -103,6 +104,7 @@ def get_op_quantization_configs() -> \ attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py new file mode 100644 index 000000000..a9b845dfa --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +__version__ = 'v4' diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py new file mode 100644 index 000000000..3dc2a8a1b --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -0,0 +1,235 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from typing import List, Tuple + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS +from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \ + TargetPlatformModel +from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \ + AttributeQuantizationConfig + +tp = mct.target_platform + + +def get_tp_model() -> TargetPlatformModel: + """ + A method that generates a default target platform model, with base 8-bit quantization configuration and 8, 4, 2 + bits configuration list for mixed-precision quantization. + NOTE: in order to generate a target platform model with different configurations but with the same Operators Sets + (for tests, experiments, etc.), use this method implementation as a test-case, i.e., override the + 'get_op_quantization_configs' method and use its output to call 'generate_tp_model' with your configurations. + This version enables metadata by default. + + Returns: A TargetPlatformModel object. + + """ + base_config, mixed_precision_cfg_list, default_config = get_op_quantization_configs() + return generate_tp_model(default_config=default_config, + base_config=base_config, + mixed_precision_cfg_list=mixed_precision_cfg_list, + name='imx500_tp_model') + + +def get_op_quantization_configs() -> \ + Tuple[OpQuantizationConfig, List[OpQuantizationConfig], OpQuantizationConfig]: + """ + Creates a default configuration object for 8-bit quantization, to be used to set a default TargetPlatformModel. + In addition, creates a default configuration objects list (with 8, 4 and 2 bit quantization) to be used as + default configuration for mixed-precision quantization. + + Returns: An OpQuantizationConfig config object and a list of OpQuantizationConfig objects. + + """ + + # TODO: currently, we don't want to quantize any attribute but the kernel by default, + # to preserve the current behavior of MCT, so quantization is disabled for all other attributes. + # Other quantization parameters are set to what we eventually want to quantize by default + # when we enable multi-attributes quantization - THIS NEED TO BE MODIFIED IN ALL TP MODELS! + + # define a default quantization config for all non-specified weights attributes. + default_weight_attr_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=False, # TODO: this will changed to True once implementing multi-attributes quantization + lut_values_bitwidth=None) + + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + kernel_base_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.SYMMETRIC, + weights_n_bits=8, + weights_per_channel_threshold=True, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the bias (for layers where there is a bias attribute). 
+ bias_config = AttributeQuantizationConfig( + weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + weights_n_bits=FLOAT_BITWIDTH, + weights_per_channel_threshold=False, + enable_weights_quantization=False, + lut_values_bitwidth=None) + + # Create a quantization config. + # A quantization configuration defines how an operator + # should be quantized on the modeled hardware: + + # We define a default config for operation without kernel attribute. + # This is the default config that should be used for non-linear operations. + eight_bits_default = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes. + linear_eight_bits = tp.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, + activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + fixed_scale=None, + fixed_zero_point=None, + simd_size=32) + + # To quantize a model using mixed-precision, create + # a list with more than one OpQuantizationConfig. + # In this example, we quantize some operations' weights + # using 2, 4 or 8 bits, and when using 2 or 4 bits, it's possible + # to quantize the operations' activations using LUT. + four_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}}, + simd_size=linear_eight_bits.simd_size * 2) + two_bits = linear_eight_bits.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}}, + simd_size=linear_eight_bits.simd_size * 4) + + mixed_precision_cfg_list = [linear_eight_bits, four_bits, two_bits] + + return linear_eight_bits, mixed_precision_cfg_list, eight_bits_default + + +def generate_tp_model(default_config: OpQuantizationConfig, + base_config: OpQuantizationConfig, + mixed_precision_cfg_list: List[OpQuantizationConfig], + name: str) -> TargetPlatformModel: + """ + Generates TargetPlatformModel with default defined Operators Sets, based on the given base configuration and + mixed-precision configurations options list. + + Args + default_config: A default OpQuantizationConfig to set as the TP model default configuration. + base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only. + mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision + quantization configuration options. + name: The name of the TargetPlatformModel. + + Returns: A TargetPlatformModel object. + + """ + # Create a QuantizationConfigOptions, which defines a set + # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example). + # If the QuantizationConfigOptions contains only one configuration, + # this configuration will be used for the operation quantization: + default_configuration_options = tp.QuantizationConfigOptions([default_config]) + + # Create a QuantizationConfigOptions for quantizing constants in functional ops. 
+def generate_tp_model(default_config: OpQuantizationConfig,
+                      base_config: OpQuantizationConfig,
+                      mixed_precision_cfg_list: List[OpQuantizationConfig],
+                      name: str) -> TargetPlatformModel:
+    """
+    Generates a TargetPlatformModel with the default Operators Sets, based on the given base configuration and
+    mixed-precision configuration options list.
+
+    Args:
+        default_config: A default OpQuantizationConfig to set as the TP model default configuration.
+        base_config: An OpQuantizationConfig to set as the TargetPlatformModel base configuration for mixed-precision purposes only.
+        mixed_precision_cfg_list: A list of OpQuantizationConfig to be used as the TP model mixed-precision
+            quantization configuration options.
+        name: The name of the TargetPlatformModel.
+
+    Returns: A TargetPlatformModel object.
+
+    """
+    # Create a QuantizationConfigOptions, which defines a set
+    # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
+    # If the QuantizationConfigOptions contains only one configuration,
+    # this configuration will be used for the operation quantization:
+    default_configuration_options = tp.QuantizationConfigOptions([default_config])
+
+    # Create a QuantizationConfigOptions for quantizing constants in functional ops.
+    # The constant configuration is similar to the default 8-bit configuration, except that it uses a PoT
+    # quantization method for the constant.
+    # Since the constants are not named attributes of the layer, we use the default_weight_attr_config to
+    # define the desired quantization properties for them.
+    const_config = default_config.clone_and_edit(
+        default_weight_attr_config=default_config.default_weight_attr_config.clone_and_edit(
+            enable_weights_quantization=True, weights_per_channel_threshold=True,
+            weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO))
+    const_configuration_options = tp.QuantizationConfigOptions([const_config])
+
+    # 16-bit inputs and outputs. Currently, only defined for consts, since they are used in operators that
+    # support 16-bit inputs and outputs.
+    const_config_input16 = const_config.clone_and_edit(
+        supported_input_activation_n_bits=(8, 16))
+    const_config_input16_output16 = const_config_input16.clone_and_edit(
+        activation_n_bits=16, is_signed=True)
+    const_configuration_options_inout16 = tp.QuantizationConfigOptions([const_config_input16_output16,
+                                                                        const_config_input16],
+                                                                       base_config=const_config_input16)
+
+    # Create a TargetPlatformModel and set its default quantization config.
+    # This default configuration will be used for all operations
+    # unless specified otherwise (see OperatorsSet, for example):
+    generated_tpm = tp.TargetPlatformModel(default_configuration_options, add_metadata=True, name=name)
+
+    # To start defining the model's components (such as operator sets and fusing patterns),
+    # use 'with' on the TargetPlatformModel instance, and create them as below:
+    with generated_tpm:
+        # Create an OperatorsSet to represent a set of operations.
+        # Each OperatorsSet has a unique label.
+        # If quantization configuration options are passed, these options will
+        # be used for operations attached to this set's label.
+        # Otherwise, it will be a configure-less set (used in fusing):
+
+        generated_tpm.set_simd_padding(is_simd_padding=True)
+
+        # May suit operations like: Dropout, Reshape, etc.
+        default_qco = tp.get_default_quantization_config_options()
+        tp.OperatorsSet("NoQuantization",
+                        default_qco.clone_and_edit(enable_activation_quantization=False,
+                                                   supported_input_activation_n_bits=(8, 16))
+                        .clone_and_edit_weight_attribute(enable_weights_quantization=False))
+
+        # Create mixed-precision quantization configuration options from the given list of OpQuantizationConfig objects.
+        mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
+                                                                             base_config=base_config)
+
+        # Define operator sets that use mixed_precision_configuration_options:
+        conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options)
+        fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
+
+        # Define operator sets without quantization configuration
+        # options (useful for creating fusing patterns, for example):
+        any_relu = tp.OperatorsSet("AnyReLU")
+        add = tp.OperatorsSet("Add", const_configuration_options_inout16)
+        sub = tp.OperatorsSet("Sub", const_configuration_options_inout16)
+        mul = tp.OperatorsSet("Mul", const_configuration_options_inout16)
+        div = tp.OperatorsSet("Div", const_configuration_options)
+        prelu = tp.OperatorsSet("PReLU")
+        swish = tp.OperatorsSet("Swish")
+        sigmoid = tp.OperatorsSet("Sigmoid")
+        tanh = tp.OperatorsSet("Tanh")
+
+        # Combine multiple operators into a single operator to avoid quantization between them.
+        # To do this, we define fusing patterns using the OperatorsSets that were created.
+        # To group multiple sets with regard to fusing, an OperatorSetConcat can be created.
+        activations_after_conv_to_fuse = tp.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh)
+        activations_after_fc_to_fuse = tp.OperatorSetConcat(any_relu, swish, sigmoid)
+        any_binary = tp.OperatorSetConcat(add, sub, mul, div)
+
+        # ------------------- #
+        # Fusions
+        # ------------------- #
+        tp.Fusing([conv, activations_after_conv_to_fuse])
+        tp.Fusing([fc, activations_after_fc_to_fuse])
+        tp.Fusing([any_binary, any_relu])
+
+    return generated_tpm
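The model generated above is what MCT returns for the IMX500 'v4' identifier. A small sketch of how the new 16-bit option can be inspected once the TPC is fetched; it mirrors the activation_16bit tests added later in this diff and assumes the attribute names used there (operator_set, qc_options, quantization_config_list):

    # Editorial sketch, not part of this patch.
    import model_compression_toolkit as mct
    from model_compression_toolkit.constants import TENSORFLOW
    from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL

    tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4')
    mul_op_set = [op_set for op_set in tpc.tp_model.operator_set if op_set.name == 'Mul'][0]

    # One of the 'Mul' candidates is expected to output 16-bit activations and accept 8/16-bit inputs.
    cfg_16bit = [cfg for cfg in mul_op_set.qc_options.quantization_config_list
                 if cfg.activation_n_bits == 16][0]
    print(cfg_16bit.activation_n_bits, cfg_16bit.supported_input_activation_n_bits)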
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
new file mode 100644
index 000000000..2d29dad20
--- /dev/null
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
@@ -0,0 +1,132 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from packaging import version
+
+from model_compression_toolkit.defaultdict import DefaultDict
+from model_compression_toolkit.constants import FOUND_SONY_CUSTOM_LAYERS
+from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_DEPTHWISE_KERNEL, \
+    KERAS_KERNEL, BIAS_ATTR, BIAS
+
+if FOUND_SONY_CUSTOM_LAYERS:
+    from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
+
+if version.parse(tf.__version__) >= version.parse("2.13"):
+    from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Identity
+else:
+    from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
+        MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+        Conv2DTranspose, Identity
+
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model
+import model_compression_toolkit as mct
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION
+
+tp = mct.target_platform
+
+
+def get_keras_tpc() -> tp.TargetPlatformCapabilities:
+    """
+    Get a Keras TargetPlatformCapabilities object with default operation sets to layers mapping.
+
+    Returns: A Keras TargetPlatformCapabilities object for the given TargetPlatformModel.
+ """ + imx500_tpc_tp_model = get_tp_model() + return generate_keras_tpc(name='imx500_tpc_keras_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + + Args: + name: Name of the TargetPlatformCapabilities. + tp_model: TargetPlatformModel object. + + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + keras_tpc = tp.TargetPlatformCapabilities(tp_model, name=name, version=TPC_VERSION) + + no_quant_list = [Identity, + tf.identity, + Reshape, + tf.reshape, + Permute, + tf.transpose, + Flatten, + Cropping2D, + ZeroPadding2D, + Dropout, + MaxPooling2D, + tf.split, + tf.quantization.fake_quant_with_min_max_vars, + tf.math.argmax, + tf.shape, + tf.math.equal, + tf.gather, + tf.cast, + tf.unstack, + tf.compat.v1.gather, + tf.nn.top_k, + tf.__operators__.getitem, + tf.strided_slice, + tf.image.combined_non_max_suppression, + tf.compat.v1.shape] + + if FOUND_SONY_CUSTOM_LAYERS: + no_quant_list.append(SSDPostProcess) + + with keras_tpc: + tp.OperationsSetToLayers("NoQuantization", no_quant_list) + tp.OperationsSetToLayers("Conv", + [Conv2D, + DepthwiseConv2D, + Conv2DTranspose, + tf.nn.conv2d, + tf.nn.depthwise_conv2d, + tf.nn.conv2d_transpose], + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. + attr_mapping={ + KERNEL_ATTR: DefaultDict({ + DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL, + tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("FullyConnected", [Dense], + attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)}) + tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu, + tf.nn.relu6, + tf.nn.leaky_relu, + ReLU, + LeakyReLU, + tp.LayerFilterParams(Activation, activation="relu"), + tp.LayerFilterParams(Activation, activation="leaky_relu")]) + tp.OperationsSetToLayers("Add", [tf.add, Add]) + tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract]) + tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply]) + tp.OperationsSetToLayers("Div", [tf.math.divide, tf.math.truediv]) + tp.OperationsSetToLayers("PReLU", [PReLU]) + tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) + tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) + tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) + + return keras_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py new file mode 100644 index 000000000..347a0a3e9 --- /dev/null +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py @@ -0,0 +1,112 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import operator + +import torch +from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \ + gather, equal, transpose, permute, argmax, squeeze +from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d +from torch.nn import Dropout, Flatten, Hardtanh, Identity +from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU +from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu + +from model_compression_toolkit.defaultdict import DefaultDict +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \ + BIAS +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model +import model_compression_toolkit as mct +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION + +tp = mct.target_platform + + +def get_pytorch_tpc() -> tp.TargetPlatformCapabilities: + """ + get a Pytorch TargetPlatformCapabilities object with default operation sets to layers mapping. + + Returns: a Pytorch TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + imx500_tpc_tp_model = get_tp_model() + return generate_pytorch_tpc(name='imx500_tpc_pytorch_tpc', tp_model=imx500_tpc_tp_model) + + +def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel): + """ + Generates a TargetPlatformCapabilities object with default operation sets to layers mapping. + Args: + name: Name of the TargetPlatformModel. + tp_model: TargetPlatformModel object. + Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel. + """ + + pytorch_tpc = tp.TargetPlatformCapabilities(tp_model, + name=name, + version=TPC_VERSION) + + # we provide attributes mapping that maps each layer type in the operations set + # that has weights attributes with provided quantization config (in the tp model) to + # its framework-specific attribute name. + # note that a DefaultDict should be provided if not all the layer types in the + # operation set are provided separately in the mapping. 
+ pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL), + BIAS_ATTR: DefaultDict(default_value=BIAS)} + + with pytorch_tpc: + tp.OperationsSetToLayers("NoQuantization", [Identity, + Dropout, + Flatten, + dropout, + flatten, + split, + operator.getitem, + reshape, + unsqueeze, + chunk, + unbind, + torch.Tensor.size, + permute, + transpose, + equal, + argmax, + gather, + topk, + squeeze, + MaxPool2d]) + + tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("FullyConnected", [Linear], + attr_mapping=pytorch_linear_attr_mapping) + tp.OperationsSetToLayers("AnyReLU", [torch.relu, + ReLU, + ReLU6, + LeakyReLU, + relu, + relu6, + leaky_relu, + tp.LayerFilterParams(Hardtanh, min_val=0), + tp.LayerFilterParams(hardtanh, min_val=0)]) + + tp.OperationsSetToLayers("Add", [operator.add, add]) + tp.OperationsSetToLayers("Sub", [operator.sub, sub]) + tp.OperationsSetToLayers("Mul", [operator.mul, mul]) + tp.OperationsSetToLayers("Div", [operator.truediv, div]) + tp.OperationsSetToLayers("PReLU", [PReLU, prelu]) + tp.OperationsSetToLayers("Swish", [SiLU, silu, Hardswish, hardswish]) + tp.OperationsSetToLayers("Sigmoid", [Sigmoid, sigmoid]) + tp.OperationsSetToLayers("Tanh", [Tanh, tanh]) + + return pytorch_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py index 9ad2cc713..056a13305 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py @@ -90,6 +90,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -102,6 +103,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza default_weight_attr_config=default_weight_attr_config, attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py index 9b47e0aec..f42f74cb0 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py @@ -88,6 +88,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza attr_weights_configs_mapping={}, activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, activation_n_bits=8, + supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, @@ -100,6 +101,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza default_weight_attr_config=default_weight_attr_config, attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_n_bits=8, + 
supported_input_activation_n_bits=8, enable_activation_quantization=True, quantization_preserving=False, fixed_scale=None, diff --git a/tests/common_tests/helpers/generate_test_tp_model.py b/tests/common_tests/helpers/generate_test_tp_model.py index bbfedfc15..6db2fa2c5 100644 --- a/tests/common_tests/helpers/generate_test_tp_model.py +++ b/tests/common_tests/helpers/generate_test_tp_model.py @@ -15,7 +15,8 @@ import copy from typing import Dict, List, Any -from model_compression_toolkit.constants import FLOAT_BITWIDTH +from model_compression_toolkit.constants import FLOAT_BITWIDTH, ACTIVATION_N_BITS_ATTRIBUTE, \ + SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST, KERNEL_ATTR, BIAS_ATTR, \ WEIGHTS_N_BITS from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, QuantizationConfigOptions @@ -30,6 +31,9 @@ def generate_test_tp_model(edit_params_dict, name=""): + # Add "supported_input_activation_n_bits" to match "activation_n_bits" if not defined. + if ACTIVATION_N_BITS_ATTRIBUTE in edit_params_dict and SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE not in edit_params_dict: + edit_params_dict[SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE] = (edit_params_dict[ACTIVATION_N_BITS_ATTRIBUTE],) base_config, op_cfg_list, default_config = get_op_quantization_configs() # separate weights attribute parameters from the requested param to edit @@ -225,6 +229,7 @@ def generate_test_op_qc(default_weight_attr_config: tp.AttributeQuantizationConf attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config}, activation_n_bits=activation_n_bits, + supported_input_activation_n_bits=activation_n_bits, activation_quantization_method=activation_quantization_method, quantization_preserving=False, fixed_scale=None, diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py new file mode 100644 index 000000000..339a19407 --- /dev/null +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py @@ -0,0 +1,88 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import numpy as np +import tensorflow as tf + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import TENSORFLOW +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL +from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest + +keras = tf.keras +layers = keras.layers + + +get_op_set = lambda x, x_list: [op_set for op_set in x_list if op_set.name == x][0] + + +class Activation16BitTest(BaseKerasFeatureNetworkTest): + + def get_tpc(self): + tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4') + # Force Mul base_config to 16bit only + mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) + mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + return tpc + + def create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = tf.multiply(inputs, inputs) + x = tf.add(x, np.ones((3,), dtype=np.float32)) + x1 = tf.subtract(x, np.ones((3,), dtype=np.float32)) + x = tf.multiply(x, x1) + x = tf.keras.layers.Conv2D(3, 1)(x) + outputs = tf.divide(x, 2*np.ones((3,), dtype=np.float32)) + return keras.Model(inputs=inputs, outputs=outputs) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + mul1_act_quant = quantized_model.layers[3] + mul2_act_quant = quantized_model.layers[9] + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 16, + "1st mul activation bits should be 16 bits because of following add node.") + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == True, + "1st mul activation should be forced by TPC to be signed, even though activations as all positive.") + self.unit_test.assertTrue(mul2_act_quant.activation_holder_quantizer.num_bits == 8, + "2nd mul activation bits should be 8 bits because of following div node.") + + +class Activation16BitMixedPrecisionTest(Activation16BitTest): + + def get_tpc(self): + tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4') + mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) + mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + mul_op_set.qc_options.quantization_config_list.extend( + [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), + mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) + tpc.layer2qco[tf.multiply].quantization_config_list.extend([ + tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=4), + tpc.layer2qco[tf.multiply].base_config.clone_and_edit(activation_n_bits=2)]) + + return tpc + + def get_resource_utilization(self): + return mct.core.ResourceUtilization(activation_memory=200) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + mul1_act_quant = quantized_model.layers[3] + mul2_act_quant = quantized_model.layers[9] + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 8, + "1st mul activation bits should be 8 bits because of RU.") + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == False, + "1st 
mul activation should be unsigned, because activations as all positive.") + self.unit_test.assertTrue(mul2_act_quant.activation_holder_quantizer.num_bits == 8, + "2nd mul activation bits should be 8 bits because of following div node.") diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/bn_attributes_quantization_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/bn_attributes_quantization_test.py index b04df1e69..44d672c65 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/bn_attributes_quantization_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/bn_attributes_quantization_test.py @@ -55,6 +55,7 @@ def _generate_bn_quantized_tpm(quantize_linear): default_weight_attr_config=default_attr_cfg, attr_weights_configs_mapping={BETA: bn_attr_cfg, GAMMA: bn_attr_cfg}, activation_n_bits=8, + supported_input_activation_n_bits=8, activation_quantization_method=QuantizationMethod.POWER_OF_TWO, quantization_preserving=False, fixed_scale=None, @@ -65,6 +66,7 @@ def _generate_bn_quantized_tpm(quantize_linear): default_weight_attr_config=default_attr_cfg, attr_weights_configs_mapping={}, activation_n_bits=8, + supported_input_activation_n_bits=8, activation_quantization_method=QuantizationMethod.POWER_OF_TWO, quantization_preserving=False, fixed_scale=None, diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py index df4cc025d..ad2eb74c3 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/weights_mixed_precision_tests.py @@ -460,7 +460,8 @@ def get_resource_utilization(self): def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): conv_layers = get_layers_from_model_by_type(quantized_model, layers.Conv2D) - assert (quantization_info.mixed_precision_cfg == [1, 0]).all() + assert any([(quantization_info.mixed_precision_cfg == [1, 0]).all(), + (quantization_info.mixed_precision_cfg == [0, 1]).all()]) for i in range(32): # quantized per channel self.unit_test.assertTrue( np.unique(conv_layers[0].get_quantized_weights()['kernel'][:, :, :, i]).flatten().shape[0] <= 256) diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 2562f40fc..b8ba5c025 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -142,6 +142,8 @@ from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest from tests.keras_tests.feature_networks_tests.feature_networks.const_quantization_test import ConstQuantizationTest, \ AdvancedConstQuantizationTest +from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \ + Activation16BitMixedPrecisionTest from model_compression_toolkit.qat.common.qat_config import TrainingMethod layers = tf.keras.layers @@ -794,9 +796,14 @@ def test_keras_tpcs(self): TpcTest(f'{C.IMX500_TP_MODEL}.v2_lut', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3', self).run_test() TpcTest(f'{C.IMX500_TP_MODEL}.v3_lut', self).run_test() + TpcTest(f'{C.IMX500_TP_MODEL}.v4', self).run_test() TpcTest(f'{C.TFLITE_TP_MODEL}.v1', self).run_test() 
TpcTest(f'{C.QNNPACK_TP_MODEL}.v1', self).run_test() + def test_16bit_activations(self): + Activation16BitTest(self).run_test() + Activation16BitMixedPrecisionTest(self).run_test() + if __name__ == '__main__': unittest.main() diff --git a/tests/keras_tests/function_tests/test_custom_layer.py b/tests/keras_tests/function_tests/test_custom_layer.py index ee6bf117c..cc99a0f45 100644 --- a/tests/keras_tests/function_tests/test_custom_layer.py +++ b/tests/keras_tests/function_tests/test_custom_layer.py @@ -65,6 +65,7 @@ def get_tpc(): base_cfg = tp.OpQuantizationConfig(activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO, enable_activation_quantization=True, activation_n_bits=32, + supported_input_activation_n_bits=32, default_weight_attr_config=attr_cfg[DEFAULT_WEIGHT_ATTR_CONFIG], attr_weights_configs_mapping={}, quantization_preserving=False, diff --git a/tests/keras_tests/layer_tests/base_keras_layer_test.py b/tests/keras_tests/layer_tests/base_keras_layer_test.py index b2067ffa8..eee4e9ede 100644 --- a/tests/keras_tests/layer_tests/base_keras_layer_test.py +++ b/tests/keras_tests/layer_tests/base_keras_layer_test.py @@ -76,7 +76,7 @@ def get_tpc(self): return get_quantization_disabled_keras_tpc("float_layer_test") elif self.current_mode == LayerTestMode.QUANTIZED_8_BITS: tp = generate_test_tp_model({'weights_n_bits': 8, - 'activation_n_bits': 8}) + 'activation_n_bits': 8}) return generate_keras_tpc(name="8bit_layer_test", tp_model=tp) else: raise NotImplemented diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py new file mode 100644 index 000000000..a8f9762b7 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from operator import mul +import torch + +import model_compression_toolkit as mct +from model_compression_toolkit.constants import PYTORCH +from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL +from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest + + +get_op_set = lambda x, x_list: [op_set for op_set in x_list if op_set.name == x][0] + + +class Activation16BitNet(torch.nn.Module): + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 1) + self.register_buffer('add_const', torch.rand((3, 1, 1))) + self.register_buffer('sub_const', torch.rand((3, 1, 1))) + self.register_buffer('div_const', 2*torch.ones((3, 1, 1))) + + def forward(self, x): + x = torch.mul(x, x) + x1 = torch.add(x, self.add_const) + x = torch.sub(x, self.sub_const) + x = torch.mul(x, x1) + x = self.conv(x) + x = torch.divide(x, self.div_const) + return x + + +class Activation16BitTest(BasePytorchFeatureNetworkTest): + + def get_tpc(self): + tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4') + mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) + mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config + tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config + return tpc + + def create_networks(self): + return Activation16BitNet() + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + mul1_act_quant = quantized_model.mul_activation_holder_quantizer + mul2_act_quant = quantized_model.mul_1_activation_holder_quantizer + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 16, + "1st mul activation bits should be 16 bits because of following add node.") + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == True, + "1st mul activation should be forced by TPC to be signed, even though activations as all positive.") + self.unit_test.assertTrue(mul2_act_quant.activation_holder_quantizer.num_bits == 8, + "2nd mul activation bits should be 8 bits because of following div node.") + + +class Activation16BitMixedPrecisionTest(Activation16BitTest): + + def get_tpc(self): + tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v4') + mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) + mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config + tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config + mul_op_set.qc_options.quantization_config_list.extend( + [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), + mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) + tpc.layer2qco[torch.mul].quantization_config_list.extend([ + tpc.layer2qco[torch.mul].base_config.clone_and_edit(activation_n_bits=4), + tpc.layer2qco[torch.mul].base_config.clone_and_edit(activation_n_bits=2)]) + tpc.layer2qco[mul].quantization_config_list.extend([ + tpc.layer2qco[mul].base_config.clone_and_edit(activation_n_bits=4), + tpc.layer2qco[mul].base_config.clone_and_edit(activation_n_bits=2)]) + + return tpc + + def get_resource_utilization(self): + return 
mct.core.ResourceUtilization(activation_memory=200) + + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + mul1_act_quant = quantized_model.mul_activation_holder_quantizer + mul2_act_quant = quantized_model.mul_1_activation_holder_quantizer + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 8, + "1st mul activation bits should be 8 bits because of RU.") + self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.signed == False, + "1st mul activation should be unsigned, because activations as all positive.") + self.unit_test.assertTrue(mul2_act_quant.activation_holder_quantizer.num_bits == 8, + "2nd mul activation bits should be 8 bits because of following div node.") diff --git a/tests/pytorch_tests/model_tests/feature_models/bn_attributes_quantization_test.py b/tests/pytorch_tests/model_tests/feature_models/bn_attributes_quantization_test.py index b07b057e0..f02f35364 100644 --- a/tests/pytorch_tests/model_tests/feature_models/bn_attributes_quantization_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/bn_attributes_quantization_test.py @@ -53,6 +53,7 @@ def _generate_bn_quantized_tpm(quantize_linear): default_weight_attr_config=default_attr_cfg, attr_weights_configs_mapping={BETA: bn_attr_cfg, GAMMA: bn_attr_cfg}, activation_n_bits=8, + supported_input_activation_n_bits=8, activation_quantization_method=QuantizationMethod.POWER_OF_TWO, quantization_preserving=False, fixed_scale=None, @@ -63,6 +64,7 @@ def _generate_bn_quantized_tpm(quantize_linear): default_weight_attr_config=default_attr_cfg, attr_weights_configs_mapping={}, activation_n_bits=8, + supported_input_activation_n_bits=8, activation_quantization_method=QuantizationMethod.POWER_OF_TWO, quantization_preserving=False, fixed_scale=None, diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 8403b44ec..31c2dcce5 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -97,6 +97,8 @@ from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \ AdvancedConstQuantizationTest from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest +from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \ + Activation16BitMixedPrecisionTest class FeatureModelsTestRunner(unittest.TestCase): @@ -656,6 +658,10 @@ def test_torch_tpcs(self): TpcTest(f'{C.TFLITE_TP_MODEL}.v1', self).run_test() TpcTest(f'{C.QNNPACK_TP_MODEL}.v1', self).run_test() + def test_16bit_activations(self): + Activation16BitTest(self).run_test() + Activation16BitMixedPrecisionTest(self).run_test() + if __name__ == '__main__': unittest.main() diff --git a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb index 16e046bc1..97c750ff4 100644 --- a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb +++ b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb @@ -1,421 +1,421 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } + "nbformat": 4, + 
"nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Structured Pruning of a Fully-Connected Keras Model\n", - "\n", - "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb) \n", - "\n", - "Welcome to this tutorial, where we will guide you through training, pruning, and retraining a fully connected Keras model. We'll begin by constructing and training a simple neural network using the Keras framework. Following this, we will introduce and apply model pruning using MCT to reduce the size of our network. Finally, we'll retrain our pruned model to recover its degraded performance due to the pruning process.\n", - "\n", - "\n", - "## Installing TensorFlow and Model Compression Toolkit\n", - "\n", - "We start by setting up our environment by installing TensorFlow and Model Compression Toolkit and importing them." - ], - "metadata": { - "id": "UJDzewEYfSN5" - } - }, - { - "cell_type": "code", - "source": [ - "!pip install model-compression-toolkit \n", - "!pip install tensorflow" - ], - "metadata": { - "id": "xTvVA__4NItc" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "id": "Q2bAksKtM0ca" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "import tensorflow_datasets as tfds\n", - "import model_compression_toolkit as mct" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Loading and Preprocessing MNIST\n", - "\n", - "Let's create a function to retrive the train and test parts of MNIST dataset including preprocessing:" - ], - "metadata": { - "id": "tW1xcK_Kf4F_" - } - }, - { - "cell_type": "code", - "source": [ - "def load_and_preprocess_mnist():\n", - " (ds_train, ds_test), ds_info = tfds.load(\n", - " 'mnist',\n", - " split=['train', 'test'],\n", - " shuffle_files=True,\n", - " as_supervised=True,\n", - " with_info=True,\n", - " )\n", - "\n", - " def normalize_img(image, label):\n", - " return tf.cast(image, tf.float32) / 255., label\n", - "\n", - " ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", - " ds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(128).prefetch(tf.data.AUTOTUNE)\n", - " ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).batch(128)\n", - "\n", - " return ds_train, ds_test\n" - ], - "metadata": { - "id": "fwtJHnflfv_f" - }, - "execution_count": 28, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Creating a Fully-Connected Model\n", - "\n", - "In this tutorial section, we create a simple toy example of a fully connected model to demonstrate the pruning process using MCT. It consists of three dense layers with 128, 64, and 10 neurons.\n", - "\n", - "Notably, MCT's structured pruning will target the first two dense layers for pruning, as these layers offer the opportunity to reduce output channels. 
This reduction can be effectively propagated by adjusting the input channels of subsequent layers.\n", - "\n", - "Once our model is created, we compile it to prepare the model for training and evaluation.\n" - ], - "metadata": { - "id": "m3vu7-uvgtfC" - } - }, - { - "cell_type": "code", - "source": [ - "def create_model():\n", - " model = tf.keras.models.Sequential([\n", - " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", - " tf.keras.layers.Dense(128, activation='relu'),\n", - " tf.keras.layers.Dense(64, activation='relu'),\n", - " tf.keras.layers.Dense(10)\n", - " ])\n", - " model.compile(\n", - " optimizer=tf.keras.optimizers.Adam(0.001),\n", - " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", - " )\n", - " return model" - ], - "metadata": { - "id": "If3oj5jSjXen" - }, - "execution_count": 29, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Training Dense Model on MNIST\n", - "\n", - "Now, we can train our model using the dataset we load and evaluate it." - ], - "metadata": { - "id": "Q_tK6Xknbtha" - } - }, - { - "cell_type": "code", - "source": [ - "# Load MNIST dataset\n", - "ds_train, ds_test = load_and_preprocess_mnist()\n", - "\n", - "# Train and evaluate the model\n", - "model = create_model()\n", - "model.fit(ds_train, epochs=6, validation_data=ds_test)\n", - "model.evaluate(ds_test)" - ], - "metadata": { - "id": "jQ3_9Z1WllVV" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Dense Model Properties\n", - "\n", - "The model.summary() function in Keras provides a snapshot of the model's architecture, including layers, their types, output shapes, and the number of parameters.\n" - ], - "metadata": { - "id": "ZQHxLrsvcLKH" - } - }, - { - "cell_type": "code", - "source": [ - "model.summary()" - ], - "metadata": { - "id": "oxdespw2eeBW" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Let's break down what we see in our model summary:\n", - "\n", - "- First Dense Layer: A fully connected layer with 128 output channels and 784 input channels.\n", - "\n", - "- Second Dense Layer: A fully connected layer with 64 output channels and 128 input channels.\n", - "\n", - "- Third Dense Layer: The final dense layer with 10 neurons (as per the number of MNIST classes) and 64 input channels.\n", - "\n", - "The total parameters amount to 109,386, which roughly requiers 427.29 KB." - ], - "metadata": { - "id": "GymibwxQehOL" - } - }, - { - "cell_type": "markdown", - "source": [ - "## MCT Structured Pruning\n", - "\n", - "### Create TPC\n", - "\n", - "Firstly, we'll set up the Target Platform Capabilities (TPC) to specify each layer's SIMD (Single Instruction, Multiple Data) size.\n", - "\n", - "In MCT, SIMD plays a crucial role in channel grouping, affecting the pruning decision process based on channel importance for each SIMD group of channels.\n", - "\n", - "We'll use the simplest structured pruning scenario for this demonstration with SIMD=1." - ], - "metadata": { - "id": "RKatTp55emtF" - } - }, - { - "cell_type": "code", - "source": [ - "simd_size = 1\n", - "\n", - "def get_tpc():\n", - " tp = mct.target_platform\n", - " default_config = tp.OpQuantizationConfig(\n", - " simd_size=simd_size,\n", - " # Notice that the model will not be quantized when using the pruning API. 
For now, use tp.QuantizationMethod.UNIFORM for quantization methods and MCT will ignore it during the pruning process.\n", - " activation_quantization_method=tp.QuantizationMethod.UNIFORM,\n", - " weights_quantization_method=tp.QuantizationMethod.UNIFORM,\n", - " activation_n_bits=None,\n", - " weights_n_bits=None,\n", - " weights_per_channel_threshold=None,\n", - " enable_weights_quantization=None,\n", - " enable_activation_quantization=None,\n", - " quantization_preserving=None,\n", - " fixed_scale=None,\n", - " fixed_zero_point=None,\n", - " weights_multiplier_nbits=None)\n", - "\n", - " default_configuration_options = tp.QuantizationConfigOptions([default_config])\n", - " tp_model = tp.TargetPlatformModel(default_configuration_options)\n", - " tpc = tp.TargetPlatformCapabilities(tp_model)\n", - " return tpc\n", - "\n" - ], - "metadata": { - "id": "wqZ71s70jXhH" - }, - "execution_count": 32, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Create a Representative Dataset\n", - "\n", - "We are creating a representative dataset to guide our model pruning process for computing importance score for each channel:" - ], - "metadata": { - "id": "SnKxedEgqdSm" - } - }, - { - "cell_type": "code", - "source": [ - "# Create a representative dataset\n", - "ds_train_as_iter = iter(ds_train)\n", - "\n", - "def representative_data_gen() -> list:\n", - " yield [next(ds_train_as_iter)[0]]" - ], - "metadata": { - "id": "SCiXV1s9jswp" - }, - "execution_count": 33, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Create Resource Utilization constraint\n", - "\n", - "We're defining a resource_utilization limit to constrain the memory usage of our pruned model.\n", - "\n", - "By setting a target that limits the model's weight memory to half of its original size (around 427KB), we aim to achieve a compression ratio of 50%:" - ], - "metadata": { - "id": "nylQtALnr9gN" - } - }, - { - "cell_type": "code", - "source": [ - "# Create a ResourceUtilization object to limit the pruned model weights memory to a certain resource constraint\n", - "dense_model_memory = 427*(2**10) # Original model weights requiers ~427KB\n", - "compression_ratio = 0.5\n", - "\n", - "resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_memory*compression_ratio)" - ], - "metadata": { - "id": "doJgwbSxsCbr" - }, - "execution_count": 34, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Prune Model\n", - "\n", - "We're ready to execute the actual pruning using MCT's keras_pruning_experimental function. The model is pruned according to our defined target Resource Utilization and using the representative dataset generated earlier.\n", - "\n", - "Each channel's importance is measured using LFH (Label-Free-Hessian)\n", - "which approximates the Hessian of the loss function w.r.t model's weights.\n", - "\n", - "In this example, we've used just one score approximation for efficiency. Although this is less time-consuming, it's worth noting that using multiple approximations would yield more precise importance scores in real-world applications. However, this precision comes with a trade-off in terms of longer processing times.\n", - "\n", - "The result is a pruned model and associated pruning information, which includes details about the pruning masks and scores for each layer." 
- ], - "metadata": { - "id": "xSP6815rsCnc" - } - }, - { - "cell_type": "code", - "source": [ - "num_score_approximations = 1\n", - "\n", - "target_platform_cap = get_tpc()\n", - "pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(\n", - " model=model,\n", - " target_resource_utilization=resource_utilization,\n", - " representative_data_gen=representative_data_gen,\n", - " target_platform_capabilities=target_platform_cap,\n", - " pruning_config=mct.pruning.PruningConfig(num_score_approximations=num_score_approximations)\n", - " )" - ], - "metadata": { - "id": "x4taG-5TxBrp" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Pruned Model Properties\n", - "\n", - "As before, we can use Keras model's API to observe the new architecture and details of the pruned model:" - ], - "metadata": { - "id": "iPd6ezZN2DNp" - } - }, - { - "cell_type": "code", - "source": [ - "pruned_model.summary()" - ], - "metadata": { - "id": "xZu4gPwz2Ptp" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Retraining Pruned Model\n", - "\n", - "After pruning models, it's common to observe a temporary drop in the model's accuracy. This decline directly results from reducing the model's complexity through pruning." - ], - "metadata": { - "id": "pAheQ9SGxB13" - } - }, - { - "cell_type": "code", - "source": [ - "pruned_model.compile(\n", - " optimizer=tf.keras.optimizers.Adam(0.001),\n", - " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", - ")\n", - "pruned_model.evaluate(ds_test)" - ], - "metadata": { - "id": "Vpihq5fpdeSA" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "However, to recover the performance, we retrain the pruned model, allowing it to adapt to its new, compressed architecture. The model can regain, and sometimes even surpass, its original accuracy through retraining." - ], - "metadata": { - "id": "IHORL34t17bA" - } - }, - { - "cell_type": "code", - "source": [ - "pruned_model.fit(ds_train, epochs=6, validation_data=ds_test)\n", - "pruned_model.evaluate(ds_test)" - ], - "metadata": { - "id": "q00zV9Jmjszo" - }, - "execution_count": null, - "outputs": [] - }, - { + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Structured Pruning of a Fully-Connected Keras Model\n", + "\n", + "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_pruning_mnist.ipynb) \n", + "\n", + "Welcome to this tutorial, where we will guide you through training, pruning, and retraining a fully connected Keras model. We'll begin by constructing and training a simple neural network using the Keras framework. Following this, we will introduce and apply model pruning using MCT to reduce the size of our network. Finally, we'll retrain our pruned model to recover its degraded performance due to the pruning process.\n", + "\n", + "\n", + "## Installing TensorFlow and Model Compression Toolkit\n", + "\n", + "We start by setting up our environment by installing TensorFlow and Model Compression Toolkit and importing them." 
+ ], + "metadata": { + "id": "UJDzewEYfSN5" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install model-compression-toolkit \n", + "!pip install tensorflow" + ], + "metadata": { + "id": "xTvVA__4NItc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "Q2bAksKtM0ca" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "import model_compression_toolkit as mct" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Loading and Preprocessing MNIST\n", + "\n", + "Let's create a function to retrive the train and test parts of MNIST dataset including preprocessing:" + ], + "metadata": { + "id": "tW1xcK_Kf4F_" + } + }, + { + "cell_type": "code", + "source": [ + "def load_and_preprocess_mnist():\n", + " (ds_train, ds_test), ds_info = tfds.load(\n", + " 'mnist',\n", + " split=['train', 'test'],\n", + " shuffle_files=True,\n", + " as_supervised=True,\n", + " with_info=True,\n", + " )\n", + "\n", + " def normalize_img(image, label):\n", + " return tf.cast(image, tf.float32) / 255., label\n", + "\n", + " ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n", + " ds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(128).prefetch(tf.data.AUTOTUNE)\n", + " ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).batch(128)\n", + "\n", + " return ds_train, ds_test\n" + ], + "metadata": { + "id": "fwtJHnflfv_f" + }, + "execution_count": 28, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Creating a Fully-Connected Model\n", + "\n", + "In this tutorial section, we create a simple toy example of a fully connected model to demonstrate the pruning process using MCT. It consists of three dense layers with 128, 64, and 10 neurons.\n", + "\n", + "Notably, MCT's structured pruning will target the first two dense layers for pruning, as these layers offer the opportunity to reduce output channels. This reduction can be effectively propagated by adjusting the input channels of subsequent layers.\n", + "\n", + "Once our model is created, we compile it to prepare the model for training and evaluation.\n" + ], + "metadata": { + "id": "m3vu7-uvgtfC" + } + }, + { + "cell_type": "code", + "source": [ + "def create_model():\n", + " model = tf.keras.models.Sequential([\n", + " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", + " tf.keras.layers.Dense(128, activation='relu'),\n", + " tf.keras.layers.Dense(64, activation='relu'),\n", + " tf.keras.layers.Dense(10)\n", + " ])\n", + " model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(0.001),\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", + " )\n", + " return model" + ], + "metadata": { + "id": "If3oj5jSjXen" + }, + "execution_count": 29, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Training Dense Model on MNIST\n", + "\n", + "Now, we can train our model using the dataset we load and evaluate it." 
+ ], + "metadata": { + "id": "Q_tK6Xknbtha" + } + }, + { + "cell_type": "code", + "source": [ + "# Load MNIST dataset\n", + "ds_train, ds_test = load_and_preprocess_mnist()\n", + "\n", + "# Train and evaluate the model\n", + "model = create_model()\n", + "model.fit(ds_train, epochs=6, validation_data=ds_test)\n", + "model.evaluate(ds_test)" + ], + "metadata": { + "id": "jQ3_9Z1WllVV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Dense Model Properties\n", + "\n", + "The model.summary() function in Keras provides a snapshot of the model's architecture, including layers, their types, output shapes, and the number of parameters.\n" + ], + "metadata": { + "id": "ZQHxLrsvcLKH" + } + }, + { + "cell_type": "code", + "source": [ + "model.summary()" + ], + "metadata": { + "id": "oxdespw2eeBW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's break down what we see in our model summary:\n", + "\n", + "- First Dense Layer: A fully connected layer with 128 output channels and 784 input channels.\n", + "\n", + "- Second Dense Layer: A fully connected layer with 64 output channels and 128 input channels.\n", + "\n", + "- Third Dense Layer: The final dense layer with 10 neurons (as per the number of MNIST classes) and 64 input channels.\n", + "\n", + "The total parameters amount to 109,386, which roughly requiers 427.29 KB." + ], + "metadata": { + "id": "GymibwxQehOL" + } + }, + { + "cell_type": "markdown", + "source": [ + "## MCT Structured Pruning\n", + "\n", + "### Create TPC\n", + "\n", + "Firstly, we'll set up the Target Platform Capabilities (TPC) to specify each layer's SIMD (Single Instruction, Multiple Data) size.\n", + "\n", + "In MCT, SIMD plays a crucial role in channel grouping, affecting the pruning decision process based on channel importance for each SIMD group of channels.\n", + "\n", + "We'll use the simplest structured pruning scenario for this demonstration with SIMD=1." + ], + "metadata": { + "id": "RKatTp55emtF" + } + }, + { + "cell_type": "code", + "source": [ + "simd_size = 1\n", + "\n", + "def get_tpc():\n", + " tp = mct.target_platform\n", + " default_config = tp.OpQuantizationConfig(\n", + " simd_size=simd_size,\n", + " # Notice that the model will not be quantized when using the pruning API. 
For now, use tp.QuantizationMethod.UNIFORM for quantization methods and MCT will ignore it during the pruning process.\n", + " activation_quantization_method=tp.QuantizationMethod.UNIFORM,\n", + " weights_quantization_method=tp.QuantizationMethod.UNIFORM,\n", + " activation_n_bits=None,\n", + " supported_input_activation_n_bits=None,\n", + " weights_n_bits=None,\n", + " weights_per_channel_threshold=None,\n", + " enable_weights_quantization=None,\n", + " enable_activation_quantization=None,\n", + " quantization_preserving=None,\n", + " fixed_scale=None,\n", + " fixed_zero_point=None,\n", + " weights_multiplier_nbits=None)\n", + "\n", + " default_configuration_options = tp.QuantizationConfigOptions([default_config])\n", + " tp_model = tp.TargetPlatformModel(default_configuration_options)\n", + " tpc = tp.TargetPlatformCapabilities(tp_model)\n", + " return tpc\n", + "\n" + ], + "metadata": { + "id": "wqZ71s70jXhH" + }, + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Create a Representative Dataset\n", + "\n", + "We are creating a representative dataset to guide our model pruning process for computing importance score for each channel:" + ], + "metadata": { + "id": "SnKxedEgqdSm" + } + }, + { + "cell_type": "code", + "source": [ + "# Create a representative dataset\n", + "ds_train_as_iter = iter(ds_train)\n", + "\n", + "def representative_data_gen() -> list:\n", + " yield [next(ds_train_as_iter)[0]]" + ], + "metadata": { + "id": "SCiXV1s9jswp" + }, + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Create Resource Utilization constraint\n", + "\n", + "We're defining a resource_utilization limit to constrain the memory usage of our pruned model.\n", + "\n", + "By setting a target that limits the model's weight memory to half of its original size (around 427KB), we aim to achieve a compression ratio of 50%:" + ], + "metadata": { + "id": "nylQtALnr9gN" + } + }, + { + "cell_type": "code", + "source": [ + "# Create a ResourceUtilization object to limit the pruned model weights memory to a certain resource constraint\n", + "dense_model_memory = 427*(2**10) # Original model weights requiers ~427KB\n", + "compression_ratio = 0.5\n", + "\n", + "resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_model_memory*compression_ratio)" + ], + "metadata": { + "id": "doJgwbSxsCbr" + }, + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Prune Model\n", + "\n", + "We're ready to execute the actual pruning using MCT's keras_pruning_experimental function. The model is pruned according to our defined target Resource Utilization and using the representative dataset generated earlier.\n", + "\n", + "Each channel's importance is measured using LFH (Label-Free-Hessian)\n", + "which approximates the Hessian of the loss function w.r.t model's weights.\n", + "\n", + "In this example, we've used just one score approximation for efficiency. Although this is less time-consuming, it's worth noting that using multiple approximations would yield more precise importance scores in real-world applications. However, this precision comes with a trade-off in terms of longer processing times.\n", + "\n", + "The result is a pruned model and associated pruning information, which includes details about the pruning masks and scores for each layer." 
+ ], + "metadata": { + "id": "xSP6815rsCnc" + } + }, + { + "cell_type": "code", + "source": [ + "num_score_approximations = 1\n", + "\n", + "target_platform_cap = get_tpc()\n", + "pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(\n", + " model=model,\n", + " target_resource_utilization=resource_utilization,\n", + " representative_data_gen=representative_data_gen,\n", + " target_platform_capabilities=target_platform_cap,\n", + " pruning_config=mct.pruning.PruningConfig(num_score_approximations=num_score_approximations)\n", + " )" + ], + "metadata": { + "id": "x4taG-5TxBrp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Pruned Model Properties\n", + "\n", + "As before, we can use Keras model's API to observe the new architecture and details of the pruned model:" + ], + "metadata": { + "id": "iPd6ezZN2DNp" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.summary()" + ], + "metadata": { + "id": "xZu4gPwz2Ptp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Retraining Pruned Model\n", + "\n", + "After pruning models, it's common to observe a temporary drop in the model's accuracy. This decline directly results from reducing the model's complexity through pruning." + ], + "metadata": { + "id": "pAheQ9SGxB13" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.compile(\n", + " optimizer=tf.keras.optimizers.Adam(0.001),\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n", + ")\n", + "pruned_model.evaluate(ds_test)" + ], + "metadata": { + "id": "Vpihq5fpdeSA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "However, to recover the performance, we retrain the pruned model, allowing it to adapt to its new, compressed architecture. The model can regain, and sometimes even surpass, its original accuracy through retraining." + ], + "metadata": { + "id": "IHORL34t17bA" + } + }, + { + "cell_type": "code", + "source": [ + "pruned_model.fit(ds_train, epochs=6, validation_data=ds_test)\n", + "pruned_model.evaluate(ds_test)" + ], + "metadata": { + "id": "q00zV9Jmjszo" + }, + "execution_count": null, + "outputs": [] + }, + { "cell_type": "markdown", - "id": "bb7e1572", "metadata": { "id": "bb7e1572" }, @@ -435,6 +435,5 @@ "limitations under the License.\n" ] } - - ] + ] } diff --git a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb index 57ee443cf..b9677772d 100644 --- a/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb +++ b/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb @@ -129,6 +129,7 @@ " BIAS_ATTR: bias_config},\n", " activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,\n", " activation_n_bits=3,\n", + " supported_input_activation_n_bits=8,\n", " enable_activation_quantization=True,\n", " quantization_preserving=False,\n", " fixed_scale=None,\n",
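Taken together, the hunks above thread the new supported_input_activation_n_bits argument through every place an OpQuantizationConfig is constructed directly (the bundled TP models, test helpers, and notebooks). A minimal sketch of such a constructor call after this change, assuming the signature used in the v4 tp_model above; the attribute config and all values are illustrative only:

    # Editorial sketch, not part of this patch.
    import model_compression_toolkit as mct
    from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
        AttributeQuantizationConfig

    tp = mct.target_platform

    # Minimal weights-attribute config (weights quantization disabled, as in the default TP models).
    weight_attr_cfg = AttributeQuantizationConfig(
        weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=8,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None)

    # 8-bit activations in and out, mirroring the configs updated throughout this diff.
    op_cfg = tp.OpQuantizationConfig(
        default_weight_attr_config=weight_attr_cfg,
        attr_weights_configs_mapping={},
        activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=8,
        supported_input_activation_n_bits=8,  # the argument added throughout this patch
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32)

    # A 16-bit-capable variant can be derived the same way the v4 tp_model does it.
    op_cfg_16bit_io = op_cfg.clone_and_edit(activation_n_bits=16,
                                            supported_input_activation_n_bits=(8, 16))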