From 2e42d5c9df4964ac51bf1bcedf1317c8d7c2ce44 Mon Sep 17 00:00:00 2001 From: Ofir Gordon Date: Wed, 6 Mar 2024 11:29:18 +0200 Subject: [PATCH] Support multiple weights attributes quantization in core (#964) Adding support in multiple attributes quantization from the MCT core side. There are three main parts to this change: * Modifying the node candidate quantization configurations list to include per-attribute candidates: a new class WeightsAttrQuantizationConfig was added and a config of this type is created for each config that appears in the attribute's TPC configs mapping. * Enabling per-attribute operations. * Restricting several features for kernel-only: Mixed precision, GPTQ, QAT, and pruning currently continue to work only on the kernel attribute. Several features, like SNC, bias correction, etc. are also relevant only for nodes that have a quantized kernel. Currently, only kernel quantization is enabled by default. --- .../core/common/framework_implementation.py | 16 + .../core/common/graph/base_graph.py | 48 ++- .../core/common/graph/base_node.py | 195 ++++++++---- .../graph/virtual_activation_weights_node.py | 17 +- .../mixed_precision/bit_width_setter.py | 46 ++- .../configurable_quantizer_utils.py | 53 ++-- .../mixed_precision/kpi_tools/kpi_data.py | 23 +- .../mixed_precision/kpi_tools/kpi_methods.py | 75 +++-- .../mixed_precision_search_manager.py | 46 +-- .../mixed_precision/sensitivity_evaluation.py | 11 +- .../solution_refinement_procedure.py | 52 +++- .../core/common/model_collector.py | 5 +- .../core/common/network_editors/actions.py | 67 ++-- .../candidate_node_quantization_config.py | 27 +- .../quantization/filter_nodes_candidates.py | 47 ++- .../quantization/node_quantization_config.py | 285 +++++++++++++++--- .../qparams_computation.py | 33 +- .../qparams_weights_computation.py | 54 +--- .../quantization/quantize_graph_weights.py | 34 +-- .../core/common/quantization/quantize_node.py | 53 ++-- .../set_node_quantization_config.py | 80 +++-- .../apply_bias_correction_to_graph.py | 27 +- ...apply_second_moment_correction_to_graph.py | 4 +- .../compute_bias_correction_of_graph.py | 59 ++-- .../statistics_correction.py | 1 - .../substitutions/batchnorm_reconstruction.py | 27 +- .../substitutions/batchnorm_refusing.py | 27 +- .../shift_negative_activation.py | 17 +- .../virtual_activation_weights_composition.py | 1 + .../substitutions/weights_activation_split.py | 18 +- .../common/visualization/nn_visualizer.py | 2 +- .../visualization/tensorboard_writer.py | 2 +- model_compression_toolkit/core/exporter.py | 4 +- .../mixed_precision_model_builder.py | 58 ++-- .../core/keras/constants.py | 2 +- .../substitutions/batchnorm_folding.py | 11 +- .../substitutions/input_scaling.py | 9 +- .../core/keras/keras_implementation.py | 43 ++- .../configurable_activation_quantizer.py | 6 +- .../configurable_weights_quantizer.py | 15 +- .../mixed_precision_model_builder.py | 52 ++-- .../back2framework/pytorch_model_builder.py | 2 +- .../configurable_activation_quantizer.py | 6 +- .../configurable_weights_quantizer.py | 15 +- .../core/pytorch/pytorch_implementation.py | 22 ++ .../core/quantization_prep_runner.py | 4 +- model_compression_toolkit/core/runner.py | 2 +- .../model_wrapper/fw_agnostic/__init__.py | 14 + .../get_inferable_quantizers.py} | 26 +- .../builder/fully_quantized_model_builder.py | 23 +- .../keras/builder/node_to_quantizer.py | 59 ++-- .../keras/builder/node_to_quantizers.py | 46 --- .../builder/fully_quantized_model_builder.py | 20 +- 
.../pytorch/builder/node_to_quantizer.py | 50 +-- .../gptq/common/gptq_graph.py | 4 +- .../gptq/keras/gptq_training.py | 16 +- .../keras/quantizer/quantization_builder.py | 22 +- .../gptq/pytorch/gptq_training.py | 25 +- .../pytorch/quantizer/quantization_builder.py | 18 +- .../qat/common/qat_config.py | 8 +- .../qat/keras/quantization_facade.py | 6 +- .../keras/quantizer/quantization_builder.py | 24 +- .../qat/pytorch/quantization_facade.py | 7 +- .../pytorch/quantizer/quantization_builder.py | 24 +- .../target_platform/op_quantization_config.py | 11 +- .../common/get_quantizer_config.py | 52 ++-- .../helpers/prep_graph_for_func_test.py | 4 +- .../feature_networks/lut_quantizer.py | 9 +- .../network_editor/change_qc_attr_test.py | 1 + .../network_editor/edit_qc_test.py | 22 +- .../network_editor/node_filter_test.py | 27 +- .../feature_networks/test_kmeans_quantizer.py | 3 + .../test_cfg_candidates_filter.py | 21 +- .../test_set_layer_to_bitwidth.py | 34 ++- ...t_symmetric_threshold_selection_weights.py | 5 +- .../test_uniform_range_selection_weights.py | 9 +- .../layer_tests/test_layers_runner.py | 5 +- .../test_lp_search_bitwidth.py | 5 +- .../set_layer_to_bitwidth_test.py | 40 ++- .../layer_tests/base_pytorch_layer_test.py | 1 + .../feature_models/lut_quantizer_test.py | 2 + 81 files changed, 1505 insertions(+), 841 deletions(-) create mode 100644 model_compression_toolkit/exporter/model_wrapper/fw_agnostic/__init__.py rename model_compression_toolkit/exporter/model_wrapper/{pytorch/builder/node_to_quantizers.py => fw_agnostic/get_inferable_quantizers.py} (64%) delete mode 100644 model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py diff --git a/model_compression_toolkit/core/common/framework_implementation.py b/model_compression_toolkit/core/common/framework_implementation.py index 4896a93c8..cb57e4405 100644 --- a/model_compression_toolkit/core/common/framework_implementation.py +++ b/model_compression_toolkit/core/common/framework_implementation.py @@ -449,3 +449,19 @@ def sensitivity_eval_inference(self, """ raise NotImplemented(f'{self.__class__.__name__} have to implement the ' f'framework\'s sensitivity_eval_inference method.') # pragma: no cover + + def get_inferable_quantizers(self, node: BaseNode): + """ + Returns sets of framework compatible weights and activation quantizers for the given node. + + Args: + node: Node to get quantizers for. + + Returns: + weight_quantizers: A dictionary between a weight's name to its quantizer. + activation_quantizers: A list of activations quantization, one for each layer output. + + """ + + raise NotImplemented(f'{self.__class__.__name__} have to implement the ' + f'framework\'s get_inferable_quantizers method.') # pragma: no cover \ No newline at end of file diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index 3eed66150..16e85cce2 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -529,6 +529,7 @@ def get_float_memory(self) -> float: return memory def get_configurable_sorted_nodes_names(self, + fw_info: FrameworkInfo, include_reused_nodes: bool = False) -> List[str]: """ Get a list of nodes' names that can be configured (namely, has one or @@ -536,45 +537,53 @@ def get_configurable_sorted_nodes_names(self, order of the graph. Args: + fw_info: FrameworkInfo object with information about the specific framework's model. 
include_reused_nodes: Whether or not to include reused nodes (False by default). Returns: List of nodes' names that can be configured (namely, has one or more weight qc candidate) sorted topology. """ - sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes=include_reused_nodes)] + sorted_names = [n.name for n in self.get_configurable_sorted_nodes(fw_info=fw_info, + include_reused_nodes=include_reused_nodes)] return sorted_names def get_weights_configurable_nodes(self, + fw_info: FrameworkInfo, include_reused_nodes: bool = False) -> List[BaseNode]: """ Get a list of nodes that their weights can be configured (namely, has one or more weight qc candidate and their weights should be quantized). Args: + fw_info: FrameworkInfo object with information about the specific framework's model. include_reused_nodes: Whether to include reused nodes (False by default). Returns: A list of nodes that their weights can be configured (namely, has one or more weight qc candidate). """ - return list(filter(lambda n: n.is_weights_quantization_enabled() - and not n.is_all_weights_candidates_equal() - and (not n.reuse or include_reused_nodes), list(self))) + # configurability is only relevant for kernel attribute quantization + potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)] + return list(filter(lambda n: n.is_weights_quantization_enabled(fw_info.get_kernel_op_attributes(n.type)[0]) + and not n.is_all_weights_candidates_equal(fw_info.get_kernel_op_attributes(n.type)[0]) + and (not n.reuse or include_reused_nodes), potential_conf_nodes)) def get_sorted_weights_configurable_nodes(self, + fw_info: FrameworkInfo, include_reused_nodes: bool = False) -> List[BaseNode]: """ Get a list of sorted nodes that their weights can be configured (namely, has one or more weight qc candidate and their weights should be quantized). Args: + fw_info: FrameworkInfo object with information about the specific framework's model. include_reused_nodes: Whether to include reused nodes (False by default). Returns: A list of nodes that their weights can be configured (namely, has one or more weight qc candidate) sorted topologically. """ - return self._sort_nodes_in_list(self.get_weights_configurable_nodes(include_reused_nodes)) + return self._sort_nodes_in_list(self.get_weights_configurable_nodes(fw_info, include_reused_nodes)) def get_activation_configurable_nodes(self) -> List[BaseNode]: """ @@ -599,6 +608,7 @@ def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]: return self._sort_nodes_in_list(self.get_activation_configurable_nodes()) def get_configurable_sorted_nodes(self, + fw_info: FrameworkInfo, include_reused_nodes: bool = False) -> List[BaseNode]: """ Get a list of nodes that can be configured (namely, has one or @@ -606,13 +616,14 @@ def get_configurable_sorted_nodes(self, The nodes are sorted according to the topological order of the graph. Args: + fw_info: fw_info: FrameworkInfo object with information about the specific framework's model. include_reused_nodes: Whether or not to include reused nodes (False by default). Returns: A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology. 
""" - weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes) + weights_configurable_nodes = self.get_weights_configurable_nodes(fw_info, include_reused_nodes) activation_configurable_nodes = self.get_activation_configurable_nodes() # combine and remove duplications @@ -637,17 +648,20 @@ def _sort_nodes_in_list(self, nodes_list: List[BaseNode]) -> List[BaseNode]: sorted_configurable_nodes.append(n) return sorted_configurable_nodes - def get_min_candidates_config(self) -> List[int]: + def get_min_candidates_config(self, fw_info: FrameworkInfo) -> List[int]: """ Builds a minimal configuration. Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate with minimal n_bits (in both weight and activation if both are quantized, or in the relevant one if only one of them is quantized) + Args: + fw_info: fw_info: FrameworkInfo object with information about the specific framework's model. + Returns: A list of candidate for each node (list on indices) """ - conf_sorted_nodes = self.get_configurable_sorted_nodes() + conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info) min_cfg_candidates = [n.find_min_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices assert all([len(lst) == 1 for lst in min_cfg_candidates]), \ @@ -655,17 +669,20 @@ def get_min_candidates_config(self) -> List[int]: return [lst[0] for lst in min_cfg_candidates] - def get_max_candidates_config(self) -> List[int]: + def get_max_candidates_config(self, fw_info: FrameworkInfo) -> List[int]: """ Builds a maximal configuration. Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate with maximal n_bits (in both weight and activation if both are quantized, or in the relevant one if only one of them is quantized) + Args: + fw_info: fw_info: FrameworkInfo object with information about the specific framework's model. + Returns: A list of candidate for each node (list on indices) """ - conf_sorted_nodes = self.get_configurable_sorted_nodes() + conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info) max_cfg_candidates = [n.find_max_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices assert all([len(lst) == 1 for lst in max_cfg_candidates]), \ @@ -673,15 +690,20 @@ def get_max_candidates_config(self) -> List[int]: return [lst[0] for lst in max_cfg_candidates] - def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]: + def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]: """ Gets the final number of bits for quantization of each weights' configurable layer. + Args: + fw_info: fw_info: FrameworkInfo object with information about the specific framework's model. + Returns: A list of pairs of (node type, node's weights quantization bitwidth). 
""" - sorted_conf_weights = self.get_sorted_weights_configurable_nodes() - return [(n, n.final_weights_quantization_cfg.weights_n_bits) for n in sorted_conf_weights] + sorted_conf_weights = self.get_sorted_weights_configurable_nodes(fw_info) + # a configurable node by definition has a kernel op + return [(n, n.final_weights_quantization_cfg.get_attr_config(self.fw_info.get_kernel_op_attributes(n.type)[0]).weights_n_bits) + for n in sorted_conf_weights] def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]: """ diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index ce46177fb..e2cf9b5ed 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -20,6 +20,7 @@ from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \ ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \ TargetPlatformCapabilities, LayerFilterParams @@ -110,20 +111,29 @@ def is_activation_quantization_enabled(self) -> bool: qc.activation_quantization_cfg.enable_activation_quantization return self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization - def is_weights_quantization_enabled(self) -> bool: + def is_weights_quantization_enabled(self, attr_name: str) -> bool: """ + Checks whether a node's weights attribute quantization is enabled. + + Args: + attr_name: An attribute to check if its quantization is enabled. Returns: Whether node weights quantization is enabled or not. """ if self.final_weights_quantization_cfg: # if we have a final configuration, then we only care to check if it enables weights quantization - return self.final_weights_quantization_cfg.enable_weights_quantization + return self.final_weights_quantization_cfg.get_attr_config(attr_name).enable_weights_quantization - for qc in self.candidates_quantization_cfg: - assert self.candidates_quantization_cfg[0].weights_quantization_cfg.enable_weights_quantization == \ - qc.weights_quantization_cfg.enable_weights_quantization - return self.candidates_quantization_cfg[0].weights_quantization_cfg.enable_weights_quantization + attr_candidates = self.get_all_weights_attr_candidates(attr_name) + candidates_enable_quantization = [c.enable_weights_quantization for c in attr_candidates] + if len(candidates_enable_quantization) > 0 and len(set(candidates_enable_quantization)) > 1: + Logger.error(f"Weights attribute {attr_name} in node {self.name} has multiple quantization candidates " + f"configuration with incompatible values.") + if all(candidates_enable_quantization): + return True + + return False def __repr__(self): """ @@ -182,8 +192,15 @@ def get_weights_list(self): Returns: A list of all non-positional weights the node holds. """ - return [self.weights[k] for k in self.weights.keys() - if self.weights[k] is not None and not isinstance(k, int)] + return [self.weights[k] for k in self.weights.keys() if not isinstance(k, int)] + + def get_node_weights_attributes(self) -> List[str]: + """ + + Returns: A list of all weights attributes that the node holds. 
+ + """ + return list(self.weights.keys()) def insert_positional_weights_to_input_list(self, input_tensors: List) -> List: """ @@ -240,11 +257,18 @@ def get_memory_bytes(self, fw_info) -> float: Returns: Number of bytes the node's memory requires. """ + # TODO: this method is used for tensorboard only. If we want to enable logging of other attributes memory + # then it needs to be modified. But, it might be better to remove this method from the BaseNode completely. + kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0] + if kernel_attr is None: + return 0 q_params, f_params = self.get_num_parameters(fw_info) if self.final_weights_quantization_cfg is None: # float coefficients memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER else: - memory = (f_params * FP32_BYTES_PER_PARAMETER) + (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes + memory = ((f_params * FP32_BYTES_PER_PARAMETER) + + (q_params * self.final_weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits + / 8)) # in bytes return memory @@ -261,29 +285,42 @@ def get_float_memory_bytes(self, fw_info) -> float: q_params, f_params = self.get_num_parameters(fw_info) return (f_params + q_params) * FP32_BYTES_PER_PARAMETER - def get_unified_weights_candidates_dict(self): + def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]: """ - In Mixed-Precision, a node can have multiple candidates for weights quantization configuration. + In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration. In order to display a single view of a node (for example, for logging in TensorBoard) we need a way to create a single dictionary from all candidates. This method is aimed to build such an unified dictionary for a node. + Args: + fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). + Returns: A dictionary containing information from node's weight quantization configuration candidates. """ - shared_attributes = [CORRECTED_BIAS_ATTRIBUTE, WEIGHTS_NBITS_ATTRIBUTE] - attr = dict() - if self.is_weights_quantization_enabled(): - attr = copy.deepcopy(self.candidates_quantization_cfg[0].weights_quantization_cfg.__dict__) - for shared_attr in shared_attributes: - if shared_attr in attr: - unified_attr = [] - for candidate in self.candidates_quantization_cfg: - unified_attr.append(getattr(candidate.weights_quantization_cfg, shared_attr)) - attr[shared_attr] = unified_attr - return attr - - def get_unified_activation_candidates_dict(self): + shared_parameters = [CORRECTED_BIAS_ATTRIBUTE, WEIGHTS_NBITS_ATTRIBUTE] + parameters_dict = dict() + # We assume that only the kernel attribute have more than one candidate, since we only allow to + # quantize the kernel using mixed precision + # TODO: need to modify if we want to present a unified config for other attributes + kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0] + if kernel_attr is None: + # This node doesn't have a kernel attribute + return {} + + if self.is_weights_quantization_enabled(kernel_attr): + parameters_dict = copy.deepcopy(self.candidates_quantization_cfg[0].weights_quantization_cfg. 
+ get_attr_config(kernel_attr).__dict__) + for shared_parameter in shared_parameters: + if shared_parameter in parameters_dict: + unified_param = [] + attr_candidates = self.get_all_weights_attr_candidates(kernel_attr) + for attr_candidate in attr_candidates: + unified_param.append(getattr(attr_candidate, shared_parameter)) + parameters_dict[shared_parameter] = unified_param + return parameters_dict + + def get_unified_activation_candidates_dict(self) -> Dict[str, Any]: """ In Mixed-Precision, a node can have multiple candidates for activation quantization configuration. In order to display a single view of a node (for example, for logging in TensorBoard) we need a way @@ -295,7 +332,7 @@ def get_unified_activation_candidates_dict(self): """ shared_attributes = [ACTIVATION_NBITS_ATTRIBUTE] attr = dict() - if self.is_weights_quantization_enabled(): + if self.is_activation_quantization_enabled(): attr = copy.deepcopy(self.candidates_quantization_cfg[0].activation_quantization_cfg.__dict__) for shared_attr in shared_attributes: if shared_attr in attr: @@ -305,7 +342,7 @@ def get_unified_activation_candidates_dict(self): attr[shared_attr] = unified_attr return attr - def is_all_activation_candidates_equal(self): + def is_all_activation_candidates_equal(self) -> bool: """ Checks whether all candidates' quantization configuration have the same activation configuration, using the self-implemented __eq__ method of class NodeActivationQuantizationConfig. @@ -317,23 +354,31 @@ def is_all_activation_candidates_equal(self): self.candidates_quantization_cfg[0].activation_quantization_cfg for candidate in self.candidates_quantization_cfg) - def is_all_weights_candidates_equal(self): + def is_all_weights_candidates_equal(self, attr: str) -> bool: """ - Checks whether all candidates' quantization configuration have the same weights configuration, + Checks whether all candidates' quantization configuration of a given weights attribute + have the same weights configuration, using the self-implemented __eq__ method of class NodeWeightsQuantizationConfig. - Returns: True if all candidates have same weights configuration, False otherwise. + Args: + attr: The attribute name to check if all its quantization configuration candidates are equal. + + Returns: True if all the weights attribute candidates have same configuration, False otherwise. """ - return all(candidate.weights_quantization_cfg == - self.candidates_quantization_cfg[0].weights_quantization_cfg - for candidate in self.candidates_quantization_cfg) + # note that if the given attribute name does not exist in the node's attributes mapping, + # the inner method would log an exception. + return all(attr_candidate == + self.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(attr) + for attr_candidate in self.get_all_weights_attr_candidates(attr)) - def has_weights_to_quantize(self, fw_info): + def has_kernel_weight_to_quantize(self, fw_info): """ - Checks whether the node has weights that need to be quantized according to the framework info. + Checks whether the node has kernel attribute that need to be quantized according to the framework info. + Args: fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). + Returns: Whether the node has weights that need to be quantized. 
""" attrs = fw_info.get_kernel_op_attributes(self.type) @@ -342,6 +387,17 @@ def has_weights_to_quantize(self, fw_info): return True return False + def has_any_weight_attr_to_quantize(self) -> bool: + """ + Checks whether the node has any weights attribute that is supposed to be quantized, based on its provided + quantization configuration candidates. + + Returns: True if the is at least one weights attribute in the node that is supposed to be quantized. + + """ + + return any([self.is_weights_quantization_enabled(attr) for attr in self.get_node_weights_attributes()]) + def get_total_output_params(self) -> float: """ Calculates the output size of the node. @@ -373,7 +429,7 @@ def find_min_candidates_indices(self) -> List[int]: """ Returns a list with potential minimal candidates. A potential minimal candidate is a candidate which its weights_n_bits and activation_n_bits pair is - on the Pareto Front, i.e., there is no other candidates that its n_bits pair exceeds in both entries. + on the Pareto Front, i.e., there is no other candidate that its n_bits pair exceeds in both entries. Returns: A list of indices of potential minimal candidates. @@ -413,20 +469,29 @@ def find_max_candidates_indices(self) -> List[int]: return [i for i, a_n_bits in max_candidates] - def get_unique_weights_candidates(self) -> List[Any]: + def get_unique_weights_candidates(self, attr: str) -> List[Any]: """ - Returns a list with node's candidates of unique weights bit-width value. - If the node have multiple candidates with the same weights bit-width, + Returns a list with node's candidates of unique weights bit-width value for the given attribute. + If the node have multiple candidates with the same weights bit-width for this attribute, the first candidate in the list is returned. - Returns: A list with node's candidates of unique weights bit-width value. + Args: + attr: A weights attribute name to get its unique candidates list. + + Returns: A list with node's candidates of unique weights bit-width value for the given attribute. """ + if attr is None or len(self.get_all_weights_attr_candidates(attr)) == 0: + Logger.warning(f"Trying to retrieve quantization configuration candidates for attribute '{attr}', " + f"but such attribute can't be found in node {self.name}." + f"An empty list of candidates is returned.") + return [] + unique_candidates = copy.deepcopy(self.candidates_quantization_cfg) seen_candidates = set() unique_candidates = [candidate for candidate in unique_candidates if - candidate.weights_quantization_cfg not in seen_candidates - and not seen_candidates.add(candidate.weights_quantization_cfg)] + candidate.weights_quantization_cfg.get_attr_config(attr) not in seen_candidates + and not seen_candidates.add(candidate.weights_quantization_cfg.get_attr_config(attr))] return unique_candidates def get_unique_activation_candidates(self) -> List[Any]: @@ -445,16 +510,6 @@ def get_unique_activation_candidates(self) -> List[Any]: and not seen_candidates.add(candidate.activation_quantization_cfg)] return unique_candidates - def has_weights_quantization_enabled_candidate(self) -> bool: - """ - Checks whether the node has quantization configuration candidates that enable weights quantization. - - Returns: True if the node has at list one quantization configuration candidate with weights quantization enabled. 
- """ - - return len(self.candidates_quantization_cfg) > 0 and \ - any([c.weights_quantization_cfg.enable_weights_quantization for c in self.candidates_quantization_cfg]) - def has_activation_quantization_enabled_candidate(self) -> bool: """ Checks whether the node has quantization configuration candidates that enable activation quantization. @@ -465,6 +520,20 @@ def has_activation_quantization_enabled_candidate(self) -> bool: return len(self.candidates_quantization_cfg) > 0 and \ any([c.activation_quantization_cfg.enable_activation_quantization for c in self.candidates_quantization_cfg]) + def get_all_weights_attr_candidates(self, attr: str) -> List[WeightsAttrQuantizationConfig]: + """ + Returns all WeightsAttrQuantizationConfig configuration of the given attribute of the node. + + Args: + attr: The attribute name to get its configurations. + + Returns: A list of the attribute's quantization configurations. + + """ + # note that if the given attribute name does not exist in the node's attributes mapping, + # the inner method would log an exception. + return [c.weights_quantization_cfg.get_attr_config(attr) for c in self.candidates_quantization_cfg] + def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions: """ Get the QuantizationConfigOptions of the node according @@ -540,3 +609,27 @@ def get_simd(self) -> int: if _simd <= 0 or int(_simd) != _simd: Logger.error(f"SIMD is expected to be a non-positive integer but found: {_simd}") return _simd + + def sort_node_candidates(self, fw_info): + """ + Sorts the node candidates. + We assume that the candidates are ordered in the following way (for mixed precision purposes): + - If the node has a kernel attribute, then we use the kernel weights number of bits to sort the candidates + (in descending order). We use the candidate activation number of bits as a secondary order. + - If the node doesn't have a kernel we only consider the candidate activation number of bits to sort + the candidates in descending order. + The operation is done inplace. + + Args: + fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize). + + """ + if self.candidates_quantization_cfg is not None: + kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0] + if kernel_attr is not None: + self.candidates_quantization_cfg.sort( + key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits, + c.activation_quantization_cfg.activation_n_bits), reverse=True) + else: + self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits, + reverse=True) diff --git a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py index 00a014d52..b64d3df49 100644 --- a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +++ b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py @@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ CandidateNodeQuantizationConfig +from model_compression_toolkit.logger import Logger class VirtualSplitNode(BaseNode): @@ -60,19 +61,20 @@ class VirtualSplitWeightsNode(VirtualSplitNode): config. """ - def __init__(self, origin_node: BaseNode): + def __init__(self, origin_node: BaseNode, kernel_attr: str): """ Init a VirtualSplitWeightsNode object. 
Args: origin_node: The original node from which the new node was split. + kernel_attr: The name of the kernel attribute of the original node. """ super().__init__(origin_node) self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX - self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates() + self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr) for c in self.candidates_quantization_cfg: c.activation_quantization_cfg.enable_activation_quantization = False c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH @@ -130,6 +132,7 @@ def __init__(self, output_shape: Tuple[Any], weights: Dict[str, np.ndarray], layer_class: type, + fw_info: FrameworkInfo, reuse: bool = False, reuse_group: str = None, quantization_attr: Dict[str, Any] = None, @@ -147,6 +150,7 @@ def __init__(self, output_shape: Input tensor shape of the node. weights: Dictionary from a variable name to the weights with that name in the layer the node represents. layer_class: Class path of the layer this node represents. + fw_info: A FrameworkInfo object with framework specific information, reuse: Whether this node was duplicated and represents a reused layer. reuse_group: Name of group of nodes from the same reused layer. quantization_attr: Attributes the node holds regarding how it should be quantized. @@ -180,7 +184,8 @@ def __init__(self, v_candidates.append(composed_candidate) # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order) - v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.weights_n_bits, + kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0] + v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits, c.activation_quantization_cfg.activation_n_bits), reverse=True) self.candidates_quantization_cfg = v_candidates @@ -197,10 +202,12 @@ def get_bops_count(self, fw_impl: Any, fw_info: FrameworkInfo, candidate_idx: in Returns: The BOPS count of the composed node. 
""" + kernel_attr = fw_info.get_kernel_op_attributes(self.original_weights_node.type)[0] node_mac = fw_impl.get_node_mac_operations(self.original_weights_node, fw_info) candidate = self.candidates_quantization_cfg[candidate_idx] - weights_bit = candidate.weights_quantization_cfg.weights_n_bits if \ - candidate.weights_quantization_cfg.enable_weights_quantization else FLOAT_BITWIDTH + kernel_attr_cfg = candidate.weights_quantization_cfg.get_attr_config(kernel_attr) + weights_bit = kernel_attr_cfg.weights_n_bits if \ + kernel_attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH activation_bit = candidate.activation_quantization_cfg.activation_n_bits if \ candidate.activation_quantization_cfg.enable_activation_quantization else FLOAT_BITWIDTH node_bops = weights_bit * activation_bit * node_mac diff --git a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py index 137678e8e..5b746eed1 100644 --- a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +++ b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py @@ -36,35 +36,40 @@ def set_bit_widths(mixed_precision_enable: bool, """ if mixed_precision_enable: - assert all([len(n.candidates_quantization_cfg) > 0 for n in graph.get_configurable_sorted_nodes()]), \ + assert all([len(n.candidates_quantization_cfg) > 0 + for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \ "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode" Logger.info(f'Set bit widths from configuration: {bit_widths_config}') # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate). - sorted_nodes_names = graph.get_configurable_sorted_nodes_names() - for node in graph.nodes: # set a specific node qc for each node final weights qc + sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info) + for node in graph.nodes: # set a specific node qc for each node final qc # If it's reused, take the configuration that the base node has node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2]) if node_name in sorted_nodes_names: # only configurable nodes are in this list node_index_in_graph = sorted_nodes_names.index(node_name) _set_node_final_qc(bit_widths_config, node, - node_index_in_graph) + node_index_in_graph, + graph.fw_info) else: if node.is_activation_quantization_enabled(): # If we are here, this means that we are in weights-only mixed-precision # (i.e., activations are quantized with fixed bitwidth or not quantized) - # and that this node doesn't have weights to quantize + # and that this node doesn't have kernel to quantize + # (since only the kernel is quantized in mixed precision). assert len(node.candidates_quantization_cfg) > 0, \ "Node need to have at least one quantization configuration in order to quantize its activation" node.final_activation_quantization_cfg = copy.deepcopy(node.candidates_quantization_cfg[0].activation_quantization_cfg) - if node.is_weights_quantization_enabled(): + + if node.has_any_weight_attr_to_quantize(): # If we are here, this means that we are in activation-only mixed-precision - # (i.e., weights are quantized with fixed bitwidth or not quantized) - # and that this node doesn't have activations to quantize + # (i.e., kernel is quantized with fixed bitwidth or not quantized) + # and that this node doesn't have activations to quantize. 
assert len(node.candidates_quantization_cfg) > 0, \ "Node need to have at least one quantization configuration in order to quantize its activation" - node.final_weights_quantization_cfg = copy.deepcopy(node.candidates_quantization_cfg[0].weights_quantization_cfg) + node.final_weights_quantization_cfg = ( + copy.deepcopy(node.candidates_quantization_cfg[0].weights_quantization_cfg)) # When working in non-mixed-precision mode, there's only one bitwidth, and we simply set the # only candidate of the node as its final weight and activation quantization configuration. @@ -79,7 +84,8 @@ def set_bit_widths(mixed_precision_enable: bool, def _get_node_qc_by_bit_widths(node: BaseNode, bit_width_cfg: List[int], - node_index_in_graph: int) -> Any: + node_index_in_graph: int, + fw_info) -> Any: """ Get the node's quantization configuration that matches to the bit width index as in the MP configuration bit_width_cfg. @@ -89,23 +95,35 @@ def _get_node_qc_by_bit_widths(node: BaseNode, node: Node to get its quantization configuration candidate. bit_width_cfg: Configuration which determines the node's desired bit width. node_index_in_graph: Index of the node in the bit_width_cfg. + fw_info: Information relevant to a specific framework about how layers should be quantized. Returns: Node quantization configuration if it was found, or None otherwise. """ + # only the weights kernel attribute is quantized in weights mixed precision at the moment + kernel_attr = fw_info.get_kernel_op_attributes(node.type) - if node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled(): + if node.is_activation_quantization_enabled(): bit_index_in_cfg = bit_width_cfg[node_index_in_graph] qc = node.candidates_quantization_cfg[bit_index_in_cfg] + return qc + elif kernel_attr is not None: + if node.is_weights_quantization_enabled(kernel_attr[0]): + bit_index_in_cfg = bit_width_cfg[node_index_in_graph] + qc = node.candidates_quantization_cfg[bit_index_in_cfg] + + return qc + Logger.critical(f'Node {node.name} quantization configuration from configuration file' # pragma: no cover f' was not found in candidates configurations.') def _set_node_final_qc(bit_width_cfg: List[int], node: BaseNode, - node_index_in_graph: int): + node_index_in_graph: int, + fw_info): """ Get the node's quantization configuration that matches to the bit width index as in the MP configuration bit_width_cfg, and use it to finalize the node's @@ -116,11 +134,13 @@ def _set_node_final_qc(bit_width_cfg: List[int], bit_width_cfg: Configuration which determines the node's desired bit width. node: Node to set its node quantization configuration. node_index_in_graph: Index of the node in the bit_width_cfg. + fw_info: Information relevant to a specific framework about how layers should be quantized. 
""" node_qc = _get_node_qc_by_bit_widths(node, bit_width_cfg, - node_index_in_graph) + node_index_in_graph, + fw_info) if node_qc is None: Logger.critical(f'Node {node.name} quantization configuration from configuration file' # pragma: no cover diff --git a/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py b/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py index 5da170831..e455c8b99 100644 --- a/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +++ b/model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py @@ -20,29 +20,46 @@ CandidateNodeQuantizationConfig -def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig]): +def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig], + kernel_attr: str = None) -> bool: """ Make sure the candidates configurations arrives in descending order. Args: node_q_cfg: Quantization configuration candidates of the node that generated the layer that will use this quantizer. + kernel_attr: Kernel attribute name to verify its candidates order, if this is a weights configurable node. - Returns: + Returns: True if the candidates are ordered in descending order for the kernel attribute bit-width first + (if this is a weights configurable node) and activation bit-width as a secondary order. """ - curmax = (np.inf, np.inf) - n_candidate_bits = [(x.weights_quantization_cfg.weights_n_bits, x.activation_quantization_cfg.activation_n_bits) - for x in node_q_cfg] - for candidate_bits in n_candidate_bits: - assert candidate_bits < curmax, f"Node's quantization configuration candidates should arrive in " \ - f"descending order of (weights_nbits, activation_nbits)." - curmax = candidate_bits + + if kernel_attr is not None: + curmax = (np.inf, np.inf) + n_candidate_bits = [(x.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits, + x.activation_quantization_cfg.activation_n_bits) + for x in node_q_cfg] + for candidate_bits in n_candidate_bits: + assert candidate_bits < curmax, f"Node's quantization configuration candidates should arrive in " \ + f"descending order of (weights_nbits, activation_nbits)." + curmax = candidate_bits + else: + # The candidates are only activation configurable + curmax = np.inf + n_candidate_bits = [x.activation_quantization_cfg.activation_n_bits for x in node_q_cfg] + for candidate_bits in n_candidate_bits: + assert candidate_bits < curmax, f"Node's quantization configuration candidates should arrive in " \ + f"descending order of activation_nbits." + curmax = candidate_bits + + return True def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig], float_weights: Any, - fw_tensor_convert_func: Callable) -> List: + fw_tensor_convert_func: Callable, + kernel_attr: str) -> List: """ Initilizes quantized weights tensors according to the given quantization configuration candidates. @@ -51,6 +68,7 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig], use this quantizer. float_weights: A tensor of the layer's weights. fw_tensor_convert_func: A function that converts a tensor to a framework specific tensor type. + kernel_attr: The kernel attribute name of the node. Only layers with kernel op can be configured. Returns: A list with the quantized weights for each candidate. 
@@ -58,13 +76,14 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig], quantized_weights = [] for qc in node_q_cfg: - qc_weights = qc.weights_quantization_cfg - q_weight = qc_weights.weights_quantization_fn(float_weights, - qc_weights.weights_n_bits, - True, - qc_weights.weights_quantization_params, - qc_weights.weights_per_channel_threshold, - qc_weights.weights_channels_axis) + qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr) + q_weight = qc_weights_attr.weights_quantization_fn(float_weights, + qc_weights_attr.weights_n_bits, + True, + qc_weights_attr.weights_quantization_params, + qc_weights_attr.weights_per_channel_threshold, + qc_weights_attr.weights_channels_axis[ + 0]) # output channel axis quantized_weights.append(fw_tensor_convert_func(q_weight)) diff --git a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py index d25eff0e4..d466996c6 100644 --- a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py @@ -91,13 +91,20 @@ def compute_nodes_weights_params(graph: Graph, fw_info: FrameworkInfo) -> np.nda weights_params = [] for n in graph.nodes: - if n.has_weights_quantization_enabled_candidate() and not n.reuse: - node_num_weights_params = 0 - for attr in fw_info.get_kernel_op_attributes(n.type): - if attr is not None: - node_num_weights_params += n.get_weights_by_keys(attr).flatten().shape[0] - - weights_params.append(node_num_weights_params) + # TODO: when enabling multiple attribute quantization by default (currently, + # only kernel quantization is enabled) we should include other attributes memory in the sum of all + # weights memory. + # When implementing this, we should just go over all attributes in the node instead of counting only kernels. + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None and not n.reuse: + kernel_candidates = n.get_all_weights_attr_candidates(kernel_attr) + if len(kernel_candidates) > 0 and any([c.enable_weights_quantization for c in kernel_candidates]): + node_num_weights_params = 0 + for attr in fw_info.get_kernel_op_attributes(n.type): + if attr is not None: + node_num_weights_params += n.get_weights_by_keys(attr).flatten().shape[0] + + weights_params.append(node_num_weights_params) return np.array(weights_params) @@ -142,7 +149,7 @@ def compute_total_bops(graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkI # Go over all configurable nodes that have kernels. for n in graph.get_topo_sorted_nodes(): - if n.has_weights_to_quantize(fw_info): + if n.has_kernel_weight_to_quantize(fw_info): # If node doesn't have weights then its MAC count is 0, and we shouldn't consider it in the BOPS count. incoming_edges = graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) assert len(incoming_edges) == 1, f"Can't compute BOPS metric for node {n.name} with multiple inputs." 
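Note: the hunks above and below repeatedly replace the old flat weights-config access (candidate.weights_quantization_cfg.weights_n_bits) with a per-attribute lookup keyed by the node's kernel attribute. The following is a minimal illustrative sketch of that lookup pattern, not part of the patch; `n`, `qc` and `fw_info` are assumed to be a graph node, one of its quantization candidates, and a FrameworkInfo object.

    # Resolve the node's kernel attribute name (None for nodes without a kernel op).
    kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
    if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr):
        # Per-attribute configuration of this candidate, replacing the old flat weights config.
        attr_cfg = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
        n_bits = attr_cfg.weights_n_bits
        is_quantized = attr_cfg.enable_weights_quantization

Since only the kernel attribute is configurable in mixed precision at this point, the KPI and bit-width computations in this patch first resolve kernel_attr and skip nodes where it is None.
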
diff --git a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py index 3d09e2a60..8ec7b0e07 100644 --- a/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py +++ b/model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py @@ -48,27 +48,36 @@ def weights_size_kpi(mp_cfg: List[int], """ weights_memory = [] - mp_nodes = graph.get_configurable_sorted_nodes_names() - weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes()] + mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) + weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes(fw_info)] if len(mp_cfg) == 0: # Computing non-configurable nodes KPI + # TODO: when enabling multiple attribute quantization by default (currently, + # only kernel quantization is enabled) we should include other attributes memory in the sum of all + # weights memory (when quantized to their default 8-bit, non-configurable). + # When implementing this, we should just go over all attributes in the node instead of counting only kernels. for n in graph.nodes: + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is None: + continue non_configurable_node = n.name not in weights_mp_nodes \ - and n.has_weights_quantization_enabled_candidate() \ and not n.reuse \ - and n.is_all_weights_candidates_equal() + and n.is_all_weights_candidates_equal(kernel_attr) if non_configurable_node: - node_nbits = n.candidates_quantization_cfg[0].weights_quantization_cfg.weights_n_bits + node_nbits = (n.candidates_quantization_cfg[0].weights_quantization_cfg + .get_attr_config(kernel_attr).weights_n_bits) node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) weights_memory.append(node_weights_memory_in_bytes) else: # Go over configurable all nodes that should be taken into consideration when computing the weights KPI. 
- for n in graph.get_sorted_weights_configurable_nodes(): + for n in graph.get_sorted_weights_configurable_nodes(fw_info): + # Only nodes with kernel op can be considered configurable + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] node_idx = mp_nodes.index(n.name) node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - node_nbits = node_qc.weights_quantization_cfg.weights_n_bits + node_nbits = node_qc.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) @@ -98,7 +107,7 @@ def activation_output_size_kpi(mp_cfg: List[int], """ activation_memory = [] - mp_nodes = graph.get_configurable_sorted_nodes_names() + mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) activation_mp_nodes = [n.name for n in graph.get_sorted_activation_configurable_nodes()] if len(mp_cfg) == 0: @@ -147,7 +156,7 @@ def total_weights_activation_kpi(mp_cfg: List[int], """ weights_activation_memory = [] - weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes()] + weights_mp_nodes = [n.name for n in graph.get_sorted_weights_configurable_nodes(fw_info)] activation_mp_nodes = [n.name for n in graph.get_sorted_activation_configurable_nodes()] if len(mp_cfg) == 0: @@ -158,15 +167,19 @@ def total_weights_activation_kpi(mp_cfg: List[int], node_weights_memory_in_bytes, node_activation_memory_in_bytes = 0, 0 # Non-configurable Weights - is_non_configurable_weights = n.name not in weights_mp_nodes and \ - n.has_weights_quantization_enabled_candidate() and \ - n.is_all_weights_candidates_equal() and \ - not n.reuse - - if is_non_configurable_weights: - node_nbits = n.candidates_quantization_cfg[0].weights_quantization_cfg.weights_n_bits - node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) - non_configurable = True + # TODO: currently considering only kernel attributes in weights KPI. When enabling multi-attribute + # quantization we need to modify this method to count all attributes. + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None: + is_non_configurable_weights = n.name not in weights_mp_nodes and \ + n.is_all_weights_candidates_equal(kernel_attr) and \ + not n.reuse + + if is_non_configurable_weights: + node_nbits = (n.candidates_quantization_cfg[0].weights_quantization_cfg + .get_attr_config(kernel_attr).weights_n_bits) + node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_nbits, fw_info) + non_configurable = True # Non-configurable Activation is_non_configurable_activation = n.name not in activation_mp_nodes and \ @@ -184,17 +197,22 @@ def total_weights_activation_kpi(mp_cfg: List[int], else: # Go over all nodes that should be taken into consideration when computing the weights or # activation KPI (all configurable nodes). - for node_idx, n in enumerate(graph.get_configurable_sorted_nodes()): + for node_idx, n in enumerate(graph.get_configurable_sorted_nodes(fw_info)): + # TODO: currently considering only kernel attributes in weights KPI. When enabling multi-attribute + # quantization we need to modify this method to count all attributes. 
+ node_qc = n.candidates_quantization_cfg[mp_cfg[node_idx]] - node_weights_nbits = node_qc.weights_quantization_cfg.weights_n_bits - node_activation_nbits = node_qc.activation_quantization_cfg.activation_n_bits # Compute node's weights memory (if no weights to quantize then set to 0) node_weights_memory_in_bytes = 0 - if n.is_weights_quantization_enabled() and not n.is_all_weights_candidates_equal(): - node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_weights_nbits, fw_info) + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None: + if n.is_weights_quantization_enabled(kernel_attr) and not n.is_all_weights_candidates_equal(kernel_attr): + node_weights_nbits = node_qc.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits + node_weights_memory_in_bytes = _compute_node_weights_memory(n, node_weights_nbits, fw_info) # Compute node's activation memory (if node's activation are not being quantized then set to 0) + node_activation_nbits = node_qc.activation_quantization_cfg.activation_n_bits node_activation_memory_in_bytes = 0 if n.is_activation_quantization_enabled() and not n.is_all_activation_candidates_equal(): node_activation_memory_in_bytes = _compute_node_activation_memory(n, node_activation_nbits) @@ -237,7 +255,7 @@ def bops_kpi(mp_cfg: List[int], virtual_bops_nodes = [n for n in graph.get_topo_sorted_nodes() if isinstance(n, VirtualActivationWeightsNode)] - mp_nodes = graph.get_configurable_sorted_nodes_names() + mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) bops = [n.get_bops_count(fw_impl, fw_info, candidate_idx=_get_node_cfg_idx(n, mp_cfg, mp_nodes)) for n in virtual_bops_nodes] return np.array(bops) @@ -261,12 +279,12 @@ def _bops_kpi(mp_cfg: List[int], """ - mp_nodes = graph.get_configurable_sorted_nodes_names() + mp_nodes = graph.get_configurable_sorted_nodes_names(fw_info) # Go over all nodes that should be taken into consideration when computing the BOPS KPI. bops = [] for n in graph.get_topo_sorted_nodes(): - if n.has_weights_to_quantize(fw_info): + if n.has_kernel_weight_to_quantize(fw_info): # If node doesn't have weights then its MAC count is 0, and we shouldn't consider it in the BOPS count. 
incoming_edges = graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX) if len(incoming_edges) != 1: @@ -283,8 +301,9 @@ def _bops_kpi(mp_cfg: List[int], node_mac = fw_impl.get_node_mac_operations(n, fw_info) node_qc = n.candidates_quantization_cfg[_get_node_cfg_idx(n, mp_cfg, mp_nodes)] - node_weights_nbits = node_qc.weights_quantization_cfg.weights_n_bits if \ - node_qc.weights_quantization_cfg.enable_weights_quantization else FLOAT_BITWIDTH + kenrel_node_qc = node_qc.weights_quantization_cfg.get_attr_config(fw_info.get_kernel_op_attributes(n.type)[0]) + node_weights_nbits = kenrel_node_qc.weights_n_bits if \ + kenrel_node_qc.enable_weights_quantization else FLOAT_BITWIDTH input_activation_nbits = input_activation_node_cfg.activation_quantization_cfg.activation_n_bits if \ input_activation_node_cfg.activation_quantization_cfg.enable_activation_quantization else FLOAT_BITWIDTH diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index 5a1f68ce6..7620a1404 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -68,8 +68,8 @@ def __init__(self, self.compute_kpi_functions = kpi_functions self.target_kpi = target_kpi - self.min_kpi_config = self.graph.get_min_candidates_config() - self.max_kpi_config = self.graph.get_max_candidates_config() + self.min_kpi_config = self.graph.get_min_candidates_config(fw_info) + self.max_kpi_config = self.graph.get_max_candidates_config(fw_info) self.min_kpi = self.compute_min_kpis() self.non_conf_kpi_dict = self._non_configurable_nodes_kpi() @@ -86,7 +86,7 @@ def get_search_space(self) -> Dict[int, List[int]]: """ indices_mapping = {} - nodes_to_configure = self.graph.get_configurable_sorted_nodes() + nodes_to_configure = self.graph.get_configurable_sorted_nodes(self.fw_info) for idx, n in enumerate(nodes_to_configure): # For each node, get all possible bitwidth indices for it # (which is a list from 0 to the length of the candidates mp_config list of the node). 
@@ -142,7 +142,7 @@ def compute_kpi_matrix(self, target: KPITarget) -> np.ndarray: """ assert isinstance(target, KPITarget), f"{target} is not a valid KPI target" - configurable_sorted_nodes = self.graph.get_configurable_sorted_nodes() + configurable_sorted_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info) kpi_matrix = [] for c, c_n in enumerate(configurable_sorted_nodes): @@ -336,9 +336,10 @@ def __init__(self, virtual_graph: Graph, original_graph: Graph): self.virtual_graph = virtual_graph self.original_graph = original_graph + self.fw_info = original_graph.fw_info - self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names() - self.origin_sorted_conf_nodes_names = self.original_graph.get_configurable_sorted_nodes_names() + self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names(self.fw_info) + self.origin_sorted_conf_nodes_names = self.original_graph.get_configurable_sorted_nodes_names(self.fw_info) self.origin_node_idx_to_cfg = {} @@ -378,19 +379,19 @@ def reconstruct_config_from_virtual_graph(self, "set of nodes.") # pragma: no cover updated_virtual_nodes = \ - [(idx, self.virtual_graph.get_configurable_sorted_nodes()[idx]) for idx in changed_virtual_nodes_idx] + [(idx, self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)[idx]) for idx in changed_virtual_nodes_idx] # Iterating only over the virtual nodes that have updated config for virtual_node_idx, n in updated_virtual_nodes: self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx) # Updating reconstructed config for all other nodes based on provided base_config - original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes() + original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info) for i in range(len(original_base_config)): if i not in list(self.origin_node_idx_to_cfg.keys()): self.update_config_at_original_idx(n=original_sorted_conf_nodes[i], origin_cfg_idx=original_base_config[i]) else: # Reconstruct entire config - for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes()): + for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)): self.reconstruct_node_config(n, virtual_mp_cfg, virtual_node_idx) res_config = [self.origin_node_idx_to_cfg[key] for key in sorted(self.origin_node_idx_to_cfg.keys())] @@ -467,10 +468,12 @@ def retrieve_weights_only_config(self, weights_node: BaseNode, virtual_node: Bas if weights_node.name in self.origin_sorted_conf_nodes_names: # It is possible that the original weights node is not configurable, # in this case we don't need to retrieve its bit-width config - weights_bitwidth = virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg.weights_n_bits + kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0] + weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg + .get_attr_config(kernel_attr).weights_n_bits) origin_cfg_idx = [i for i, c in enumerate(weights_node.candidates_quantization_cfg) if - c.weights_quantization_cfg.weights_n_bits == weights_bitwidth] + c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth] self.update_config_at_original_idx(weights_node, origin_cfg_idx[0]) @@ -519,11 +522,14 @@ def retrieve_activation_weights_config(self, activation_bitwidth = activation_node.candidates_quantization_cfg[virtual_mp_cfg[ 
self.virtual_sorted_nodes_names.index(activation_node.name)]].activation_quantization_cfg.activation_n_bits - weights_bitwidth = virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg.weights_n_bits + kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0] + + weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg + .get_attr_config(kernel_attr).weights_n_bits) origin_cfg_idx = [i for i, c in enumerate(weights_node.origin_node.candidates_quantization_cfg) if - c.weights_quantization_cfg.weights_n_bits == weights_bitwidth and + c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and c.activation_quantization_cfg.activation_n_bits == activation_bitwidth] self.update_config_at_original_idx(weights_node.origin_node, origin_cfg_idx[0]) @@ -547,14 +553,17 @@ def retrieve_weights_activation_config(self, virtual_mp_cfg: The virtual graph's chosen mp config. """ - weights_bitwidth = weights_node.candidates_quantization_cfg[virtual_mp_cfg[ - self.virtual_sorted_nodes_names.index(weights_node.name)]].weights_quantization_cfg.weights_n_bits + kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0] + + weights_bitwidth = (weights_node.candidates_quantization_cfg[virtual_mp_cfg[ + self.virtual_sorted_nodes_names.index(weights_node.name)]] + .weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits) activation_bitwidth = virtual_node.candidates_quantization_cfg[ virtual_cfg_idx].activation_quantization_cfg.activation_n_bits origin_cfg_idx = [i for i, c in enumerate(activation_node.origin_node.candidates_quantization_cfg) if - c.weights_quantization_cfg.weights_n_bits == weights_bitwidth and + c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and c.activation_quantization_cfg.activation_n_bits == activation_bitwidth] self.update_config_at_original_idx(activation_node.origin_node, origin_cfg_idx[0]) @@ -624,8 +633,9 @@ def get_weights_for_split_activation(self, weights_node = matching_weights_node[0] if isinstance(weights_node, VirtualActivationWeightsNode): - if weights_node.original_weights_node.is_weights_quantization_enabled() and not \ - weights_node.original_weights_node.is_all_weights_candidates_equal(): + kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0] + if weights_node.original_weights_node.is_weights_quantization_enabled(kernel_attr) and not \ + weights_node.original_weights_node.is_all_weights_candidates_equal(kernel_attr): assert weights_node.name in self.virtual_sorted_nodes_names # Sanity check # The original node is both weights and activation configurable self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg) diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py index b11b4ad94..63a47ea7b 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py @@ -82,7 +82,7 @@ def __init__(self, f" an HessianInfoService object must be provided but is {hessian_info_service}") self.hessian_info_service = hessian_info_service - self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names() + self.sorted_configurable_nodes_names = 
graph.get_configurable_sorted_nodes_names(self.fw_info) # Get interest points and output points set for distance measurement and set other helper datasets # We define a separate set of output nodes of the model for the purpose of sensitivity computation. @@ -486,7 +486,7 @@ def get_mp_interest_points(graph: Graph, def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]: """ - Returns a list of output nodes that are also quantized (either weights or activation) + Returns a list of output nodes that are also quantized (either kernel weights attribute or activation) to be used as a set of output points in the distance metric computation. Args: @@ -495,8 +495,11 @@ def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]: Returns: A list of output nodes. """ - return [n.node for n in graph.get_outputs() if (n.node.is_weights_quantization_enabled() or - n.node.is_activation_quantization_enabled())] + + return [n.node for n in graph.get_outputs() + if (graph.fw_info.is_kernel_op(n.node.type) and + n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or + n.node.is_activation_quantization_enabled()] def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]: diff --git a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py index 7392723ce..bab966ea3 100644 --- a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +++ b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py @@ -64,8 +64,13 @@ def greedy_solution_refinement_procedure(mp_solution: List[int], # layer has max config in the given solution, nothing to optimize continue - node_candidates = search_manager.graph.get_configurable_sorted_nodes()[node_idx].candidates_quantization_cfg - valid_candidates = _get_valid_candidates_indices(node_candidates, new_solution[node_idx]) + current_node = search_manager.graph.get_configurable_sorted_nodes(search_manager.fw_info)[node_idx] + node_candidates = current_node.candidates_quantization_cfg + + # only the kernel weights attribute is quantized with weights mixed precision + kernel_attr = search_manager.fw_info.get_kernel_op_attributes(current_node.type) + kernel_attr = None if kernel_attr is None else kernel_attr[0] + valid_candidates = _get_valid_candidates_indices(node_candidates, new_solution[node_idx], kernel_attr) # Create a list of KPIs for the valid candidates.
updated_kpis = [] @@ -75,23 +80,22 @@ def greedy_solution_refinement_procedure(mp_solution: List[int], updated_kpis.append(node_updated_kpis) # filter out new configs that don't hold the KPI restrictions - node_filtered_kpis = [(node_idx, kpis) for node_idx, kpis in zip(valid_candidates,updated_kpis) if - target_kpi.holds_constraints(kpis)] + node_filtered_kpis = [(node_idx, kpis) for node_idx, kpis in zip(valid_candidates, updated_kpis) if + target_kpi.holds_constraints(kpis)] if len(node_filtered_kpis) > 0: sorted_by_kpi = sorted(node_filtered_kpis, key=lambda node_kpis: (node_kpis[1].total_memory, - node_kpis[1].weights_memory, - node_kpis[1].activation_memory)) + node_kpis[1].weights_memory, + node_kpis[1].activation_memory)) nodes_kpis[node_idx] = sorted_by_kpi[0][1] nodes_next_candidate[node_idx] = sorted_by_kpi[0][0] - if len(nodes_kpis) > 0: # filter out new configs that don't hold the KPI restrictions node_filtered_kpis = [(node_idx, kpis) for node_idx, kpis in nodes_kpis.items()] sorted_by_kpi = sorted(node_filtered_kpis, key=lambda node_kpis: (node_kpis[1].total_memory, - node_kpis[1].weights_memory, - node_kpis[1].activation_memory)) + node_kpis[1].weights_memory, + node_kpis[1].activation_memory)) node_idx_to_upgrade = sorted_by_kpi[0][0] new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade] @@ -102,22 +106,38 @@ def greedy_solution_refinement_procedure(mp_solution: List[int], def _get_valid_candidates_indices(node_candidates: List[CandidateNodeQuantizationConfig], - current_chosen_index: int) -> List[int]: + current_chosen_index: int, + kernel_attr: str = None) -> List[int]: """ Find node's valid candidates to try and improve the node's MP chosen candidate. - Valid indices are indices of candidates that have higher number of bits for both weights and activations. + Valid indices are indices of candidates that have higher number of bits for both weights and activations + (if they are quantized in this node). Args: node_candidates: Candidates of the node. current_chosen_index: Current index in MP configuration of the node. + kernel_attr: The name of the kernel attribute on the node, otherwise None. Returns: List of indices of valid candidates. """ - current_candidate = node_candidates[current_chosen_index] - weights_num_bits = current_candidate.weights_quantization_cfg.weights_n_bits - activation_num_bits = current_candidate.activation_quantization_cfg.activation_n_bits - # Filter candidates that have higher bit-width for both weights and activations (except for the current index). 
- return [i for i, c in enumerate(node_candidates) if c.activation_quantization_cfg.activation_n_bits >= activation_num_bits and c.weights_quantization_cfg.weights_n_bits >= weights_num_bits and not (c.activation_quantization_cfg.activation_n_bits == activation_num_bits and c.weights_quantization_cfg.weights_n_bits == weights_num_bits)] + current_candidate = node_candidates[current_chosen_index] + if kernel_attr is None: + # In this node we only quantize activation, so no need to check weights number of bits + activation_num_bits = current_candidate.activation_quantization_cfg.activation_n_bits + + # Filter candidates that have higher bit-width for activations + return [i for i, c in enumerate(node_candidates) if + c.activation_quantization_cfg.activation_n_bits >= activation_num_bits + and not (c.activation_quantization_cfg.activation_n_bits == activation_num_bits)] + else: + weights_num_bits = current_candidate.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits + activation_num_bits = current_candidate.activation_quantization_cfg.activation_n_bits + + # Filter candidates that have higher bit-width for both weights and activations (except for the current index). + return [i for i, c in enumerate(node_candidates) if + c.activation_quantization_cfg.activation_n_bits >= activation_num_bits + and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits >= weights_num_bits + and not (c.activation_quantization_cfg.activation_n_bits == activation_num_bits + and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_num_bits)] diff --git a/model_compression_toolkit/core/common/model_collector.py b/model_compression_toolkit/core/common/model_collector.py index 27548324e..f925aae50 100644 --- a/model_compression_toolkit/core/common/model_collector.py +++ b/model_compression_toolkit/core/common/model_collector.py @@ -100,9 +100,10 @@ def __init__(self, graph: Graph, # Assign statisitcs collectors to nodes for n in graph.get_topo_sorted_nodes(): sc = create_stats_collector_for_node(n, fw_info=fw_info) # Get static collector for the node - # If we use bias correction, and the node has coefficients to quantize, we need to make sure + # If we use bias correction, and the node has kernel weights to quantize, we need to make sure # its previous nodes' tensors are consistent with this node. - if qc.weights_bias_correction and n.is_weights_quantization_enabled(): + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): for ie in graph.incoming_edges(n): input_node = ie.source_node create_tensor2node(graph, diff --git a/model_compression_toolkit/core/common/network_editors/actions.py b/model_compression_toolkit/core/common/network_editors/actions.py index f3d11df28..42db64337 100644 --- a/model_compression_toolkit/core/common/network_editors/actions.py +++ b/model_compression_toolkit/core/common/network_editors/actions.py @@ -17,6 +17,7 @@ from collections import namedtuple from typing import Callable +from mct_quantizers import QuantizationMethod from model_compression_toolkit.core.common import Graph from model_compression_toolkit.logger import Logger @@ -37,11 +38,12 @@ class EditRule(_EditRule): and the action is applied on these nodes during the quantization process.
Examples: - Create an EditRule to quantize all Conv2D wights using 9 bits: + Create an EditRule to quantize all Conv2D kernel attribute weights using 9 bits: >>> import model_compression_toolkit as mct + >>> from model_compression_toolkit.core.keras.constants import KERNEL >>> from tensorflow.keras.layers import Conv2D - >>> er_list = [mct.network_editor.EditRule(filter=mct.network_editor.NodeTypeFilter(Conv2D), action=mct.network_editor.ChangeCandidatesWeightsQuantConfigAttr(weights_n_bits=9))] + >>> er_list = [mct.network_editor.EditRule(filter=mct.network_editor.NodeTypeFilter(Conv2D), action=mct.network_editor.ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, weights_n_bits=9))] Then the rules list can be passed to :func:`~model_compression_toolkit.keras_post_training_quantization` to modify the network during the quantization process. @@ -84,12 +86,14 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction): Change attributes in a layer's weights quantization configuration candidates. """ - def __init__(self, **kwargs): + def __init__(self, attr_name: str = None, **kwargs): """ Args: + attr_name: The weights attribute's name to set the weights quantization params function for. kwargs: Dictionary of attr_name and attr_value to change layer's weights quantization configuration candidates. """ self.kwargs = kwargs + self.attr_name = attr_name def apply(self, node: BaseNode, graph, fw_info): """ @@ -103,9 +107,11 @@ def apply(self, node: BaseNode, graph, fw_info): Returns: The node after its weights' quantization config candidates have been modified. """ + for nqc in node.candidates_quantization_cfg: - for attr_name, attr_value in self.kwargs.items(): - nqc.weights_quantization_cfg.set_quant_config_attr(attr_name, attr_value) + for parameter_name, parameter_value in self.kwargs.items(): + nqc.weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value, + attr_name=self.attr_name) class ChangeFinalWeightsQuantConfigAttr(BaseAction): @@ -113,17 +119,20 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction): Change attributes in a layer's final weights quantization config. """ - def __init__(self, **kwargs): + def __init__(self, attr_name: str = None, **kwargs): """ Args: + attr_name: The weights attribute's name to set the weights quantization params function for. kwargs: Dictionary of attr_name and attr_value to change layer's final weights quantization config. """ self.kwargs = kwargs + self.attr_name = attr_name def apply(self, node: BaseNode, graph, fw_info): if node.final_weights_quantization_cfg is not None: - for attr_name, attr_value in self.kwargs.items(): - node.final_weights_quantization_cfg.set_quant_config_attr(attr_name, attr_value) + for parameter_name, parameter_value in self.kwargs.items(): + node.final_weights_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value, + self.attr_name) class ChangeCandidatesActivationQuantConfigAttr(BaseAction): @@ -151,8 +160,8 @@ def apply(self, node: BaseNode, graph, fw_info): The node after its activation quantization configuration candidates have been modified. 
""" for nqc in node.candidates_quantization_cfg: - for attr_name, attr_value in self.kwargs.items(): - nqc.activation_quantization_cfg.set_quant_config_attr(attr_name, attr_value) + for parameter_name, parameter_value in self.kwargs.items(): + nqc.activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value) class ChangeFinalActivationQuantConfigAttr(BaseAction): @@ -169,8 +178,8 @@ def __init__(self, **kwargs): def apply(self, node: BaseNode, graph, fw_info): if node.final_activation_quantization_cfg is not None: - for attr_name, attr_value in self.kwargs.items(): - node.final_activation_quantization_cfg.set_quant_config_attr(attr_name, attr_value) + for parameter_name, parameter_value in self.kwargs.items(): + node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value) class ChangeQuantizationParamFunction(BaseAction): @@ -178,16 +187,21 @@ class ChangeQuantizationParamFunction(BaseAction): Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function. """ - def __init__(self, activation_quantization_params_fn=None, weights_quantization_params_fn=None): + def __init__(self, + attr_name: str = None, + activation_quantization_params_fn: Callable = None, + weights_quantization_params_fn: Callable = None): """ Init a ChangeQuantizationParamFunction object. Args: + attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params). activation_quantization_params_fn: a params function for a node's activations. weights_quantization_params_fn: a params function for a node's weights. """ self.activation_quantization_params_fn = activation_quantization_params_fn self.weights_quantization_params_fn = weights_quantization_params_fn + self.attr_name = attr_name def apply(self, node: BaseNode, graph, fw_info): """ @@ -207,7 +221,8 @@ def apply(self, node: BaseNode, graph, fw_info): nqc.activation_quantization_cfg.set_activation_quantization_params_fn( self.activation_quantization_params_fn) if self.weights_quantization_params_fn is not None: - nqc.weights_quantization_cfg.set_weights_quantization_params_fn(self.weights_quantization_params_fn) + (nqc.weights_quantization_cfg.get_attr_config(self.attr_name) + .set_weights_quantization_params_fn(self.weights_quantization_params_fn)) class ChangeFinalActivationQuantizationMethod(BaseAction): @@ -301,15 +316,17 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction): Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function. """ - def __init__(self, weights_quantization_method=None): + def __init__(self, attr_name: str, weights_quantization_method=None): """ Init a ChangeFinalWeightsQuantizationMethod object. Args: + attr_name: The weights attribute's name to set the weights quantization method for. weights_quantization_method: a quantization method for a node's weights. 
""" self.weights_quantization_method = weights_quantization_method + self.attr_name = attr_name def apply(self, node: BaseNode, graph, fw_info): """ @@ -329,15 +346,18 @@ def apply(self, node: BaseNode, graph, fw_info): weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method) - node.final_weights_quantization_cfg.set_weights_quantization_params_fn(weights_quantization_params_fn) + (node.final_weights_quantization_cfg.get_attr_config(self.attr_name) + .set_weights_quantization_params_fn(weights_quantization_params_fn)) weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method) if weights_quantization_fn is None: raise Exception('Unknown quantization method for weights') # pragma: no cover - node.final_weights_quantization_cfg.set_weights_quantization_fn(weights_quantization_fn) - node.final_weights_quantization_cfg.weights_quantization_method = self.weights_quantization_method + (node.final_weights_quantization_cfg.get_attr_config(self.attr_name) + .set_weights_quantization_fn(weights_quantization_fn)) + node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \ + self.weights_quantization_method class ChangeCandidatesWeightsQuantizationMethod(BaseAction): @@ -345,14 +365,16 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction): Class ChangeCandidatesWeightsQuantizationMethod to change a node's weights quantizer function. """ - def __init__(self, weights_quantization_method=None): + def __init__(self, attr_name: str, weights_quantization_method: QuantizationMethod = None): """ Init a ChangeCandidatesWeightsQuantizationMethod object. Args: weights_quantization_method: a quantization method for a node's weights. + attr_name: The weights attribute's name to set the weights quantization params function for. """ self.weights_quantization_method = weights_quantization_method + self.attr_name = attr_name def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo): """ @@ -373,15 +395,16 @@ def apply(self, node: BaseNode, graph: Graph, fw_info: FrameworkInfo): weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method) - qc.weights_quantization_cfg.set_weights_quantization_params_fn(weights_quantization_params_fn) + attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name) + attr_qc.set_weights_quantization_params_fn(weights_quantization_params_fn) weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method) if weights_quantization_fn is None: raise Exception('Unknown quantization method for weights') # pragma: no cover - qc.weights_quantization_cfg.set_weights_quantization_fn(weights_quantization_fn) - qc.weights_quantization_cfg.weights_quantization_method = self.weights_quantization_method + attr_qc.set_weights_quantization_fn(weights_quantization_fn) + attr_qc.weights_quantization_method = self.weights_quantization_method class ReplaceLayer(BaseAction): diff --git a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py index b624f7f14..82164cfd8 100644 --- a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import Callable +from typing import Callable, List, Tuple from model_compression_toolkit.core import QuantizationConfig from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \ @@ -40,10 +40,8 @@ def __init__(self, activation_quantization_fn: Callable = None, activation_quantization_params_fn: Callable = None, weights_quantization_cfg: NodeWeightsQuantizationConfig = None, - weights_quantization_fn: Callable = None, - weights_quantization_params_fn: Callable = None, - weights_channels_axis: int = None, - weights_cfg: AttributeQuantizationConfig = None): + weights_channels_axis: Tuple[int, int] = None, + node_attrs_list: List[str] = None): """ Args: @@ -53,10 +51,8 @@ def __init__(self, activation_quantization_fn: Function to use when quantizing the node's activations. activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations. weights_quantization_cfg: An option to pass a NodeWeightsQuantizationConfig to create a new config from. - weights_quantization_fn: Function to use when quantizing the node's weights. - weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights. - weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel. - weights_cfg: Weights attribute quantization config. + weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel. + node_attrs_list: A list of the node's weights attributes names. """ if activation_quantization_cfg is not None: @@ -74,14 +70,9 @@ def __init__(self, if weights_quantization_cfg is not None: self.weights_quantization_cfg = weights_quantization_cfg else: - if any(v is None for v in (qc, op_cfg, weights_quantization_fn, weights_quantization_params_fn, - weights_cfg)): + if any(v is None for v in (qc, op_cfg, node_attrs_list)): Logger.error("Missing some required arguments to initialize " "a node weights quantization configuration.") - self.weights_quantization_cfg = ( - NodeWeightsQuantizationConfig(qc=qc, - op_cfg=op_cfg, - weights_quantization_fn=weights_quantization_fn, - weights_quantization_params_fn=weights_quantization_params_fn, - weights_channels_axis=weights_channels_axis, - weights_cfg=weights_cfg)) + self.weights_quantization_cfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg, + weights_channels_axis=weights_channels_axis, + node_attrs_list=node_attrs_list) diff --git a/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py b/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py index e81c2f7ac..76ef891cc 100644 --- a/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +++ b/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py @@ -35,18 +35,21 @@ def filter_nodes_candidates(graph: Graph): """ nodes = list(graph.nodes) for n in nodes: - n.candidates_quantization_cfg = filter_node_candidates(node=n) + n.candidates_quantization_cfg = filter_node_candidates(node=n, fw_info=graph.fw_info) return graph -def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig]) -> List[CandidateNodeQuantizationConfig]: +def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig], + kernel_attr: str = None) -> List[CandidateNodeQuantizationConfig]: """ Filters out duplications in candidates configuration 
list, based on similarity in (weights_n_bits, weights_quantization_method, activation_n_bits, activation_quantization_method). + Weights quantization configuration considers only kernel attributes. Args: candidates: A list of quantization configuration candidates. + kernel_attr: The name of the node's kernel attribute if such exists. Returns: A filtered list of quantization configuration candidates. @@ -54,8 +57,12 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig]) - seen_bits_method_combinations = set() final_candidates = [] for c in candidates: - comb = (c.weights_quantization_cfg.weights_n_bits, - c.weights_quantization_cfg.weights_quantization_method, + weight_n_bits = None if kernel_attr is None else ( + c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits) + weights_quantization_method = None if kernel_attr is None else ( + c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method) + comb = (weight_n_bits, + weights_quantization_method, c.activation_quantization_cfg.activation_n_bits, c.activation_quantization_cfg.activation_quantization_method) if comb not in seen_bits_method_combinations: @@ -65,7 +72,7 @@ def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig]) - return final_candidates -def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]: +def filter_node_candidates(node: BaseNode, fw_info) -> List[CandidateNodeQuantizationConfig]: """ Updates a node's candidates configuration list. If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list @@ -75,19 +82,27 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf Args: node: Node to set its quantization configurations. + fw_info: FrameworkInfo object with information about the specific framework's model. + """ filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg) final_candidates = copy.deepcopy(node.candidates_quantization_cfg) + kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0] - if not node.is_weights_quantization_enabled() and not node.is_activation_quantization_enabled(): - # If both weights and activation quantization are disabled, but for some reason the node has multiple candidates - # then replace it with a single dummy candidate with default bit-width values. + if (kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr)) and not node.is_activation_quantization_enabled(): + # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel, + # but for some reason the node has multiple candidates then replace it with a single dummy candidate with + # default bit-width values. 
single_dummy_candidate = filtered_candidates[0] single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH - single_dummy_candidate.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH - single_dummy_candidate.weights_quantization_cfg.weights_quantization_method = QuantizationMethod.POWER_OF_TWO single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO + + if kernel_attr is not None: + kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(kernel_attr) + kernel_config.weights_n_bits = FLOAT_BITWIDTH + kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO + final_candidates = [single_dummy_candidate] elif not node.is_activation_quantization_enabled(): @@ -102,9 +117,9 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO - final_candidates = _filter_bit_method_dups(filtered_candidates) + final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr) - elif not node.is_weights_quantization_enabled(): + elif kernel_attr is None or not node.is_weights_quantization_enabled(kernel_attr): # Remove candidates that have duplicated activation candidates for node with disabled weights quantization. # Replacing the weights n_bits in the remained configurations with default value to prevent confusion. seen_candidates = set() @@ -113,9 +128,11 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf and not seen_candidates.add(candidate.activation_quantization_cfg)] for c in filtered_candidates: - c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH - c.weights_quantization_cfg.weights_quantization_method = QuantizationMethod.POWER_OF_TWO + if kernel_attr is not None: + kernel_config = c.weights_quantization_cfg.get_attr_config(kernel_attr) + kernel_config.weights_n_bits = FLOAT_BITWIDTH + kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO - final_candidates = _filter_bit_method_dups(filtered_candidates) + final_candidates = _filter_bit_method_dups(filtered_candidates, kernel_attr) return final_candidates diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index b5be7735e..d72ca37e5 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -14,10 +14,11 @@ # ============================================================================== -from typing import Callable, Any +from typing import Callable, Any, List, Tuple, Union, Dict import numpy as np +from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \ get_activation_quantization_params_fn, get_weights_quantization_params_fn @@ -40,18 +41,25 @@ class BaseNodeQuantizationConfig(object): Base class for node quantization configuration """ - def set_quant_config_attr(self, attr_name, attr_value): + def set_quant_config_attr(self, parameter_name: str, parameter_value: Any, + *args: List[Any], **kwargs: Dict[str, Any]): """ 
- Changes a BaseNodeQuantizationConfig's attribute. + Changes a BaseNodeQuantizationConfig's parameter. + Note that args and kwargs are only used to allow a clean override in the child classes. Args: - attr_name: attribute name to change. - attr_value: attribute value to change. + parameter_name: parameter name to change. + parameter_value: parameter value to change. + args: A list of additional arguments. + kwargs: A dictionary of additional keyword arguments. """ - if hasattr(self, attr_name): - setattr(self, attr_name, attr_value) + if hasattr(self, parameter_name): + setattr(self, parameter_name, parameter_value) + else: + Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config and " + f"was not updated!") def __repr__(self) -> str: """ @@ -228,44 +236,32 @@ def __hash__(self): self.shift_negative_threshold_recalculation)) -class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig): +class WeightsAttrQuantizationConfig: """ - Attributes for configuring the quantization of the weights of a node. + Configuration for quantizing a weights attribute of a node. """ def __init__(self, qc: QuantizationConfig, - op_cfg: OpQuantizationConfig, - weights_quantization_fn: Callable, - weights_quantization_params_fn: Callable, - weights_channels_axis: int, - weights_cfg: AttributeQuantizationConfig): + weights_attr_cfg: AttributeQuantizationConfig, + weights_channels_axis: Tuple[int, int] = None): """ Args: qc: QuantizationConfig to create the node's config from. - op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration. - weights_quantization_fn: Function to use when quantizing the node's weights. - weights_quantization_params_fn: Function to use when computing the threshold for quantizing a node's weights. - weights_channels_axis: Axis to quantize a node's kernel when quantizing per-channel. - weights_cfg: Weights attribute quantization config. + weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config. + weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (None is expected if not quantizing per-channel).
""" - - # TODO: after refactoring to enable attributes quantization, all weights quantization arguments - # should be taken per attribute, and not from the weights config - self.weights_quantization_fn = weights_quantization_fn - self.weights_quantization_params_fn = weights_quantization_params_fn + self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method) + self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method) self.weights_channels_axis = weights_channels_axis self.weights_quantization_params = {} - self.weights_quantization_method = weights_cfg.weights_quantization_method + self.weights_quantization_method = weights_attr_cfg.weights_quantization_method self.weights_error_method = qc.weights_error_method - self.weights_n_bits = weights_cfg.weights_n_bits - self.weights_bias_correction = qc.weights_bias_correction - self.weights_second_moment_correction = qc.weights_second_moment_correction - self.weights_per_channel_threshold = weights_cfg.weights_per_channel_threshold - self.enable_weights_quantization = weights_cfg.enable_weights_quantization - self.min_threshold = qc.min_threshold + self.weights_n_bits = weights_attr_cfg.weights_n_bits + self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold + self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization self.l_p_value = qc.l_p_value - self.simd_size = op_cfg.simd_size + @property @@ -287,7 +283,6 @@ def weights_error_method(self, value: QuantizationErrorMethod): self._weights_error_method = value self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_quantization_method=self.weights_quantization_method) - def set_weights_quantization_fn(self, weights_quantization_fn: Callable): """ Sets weights quantization function for the node. @@ -321,10 +316,11 @@ def set_weights_quantization_param(self, for param_name, param_value in weights_params.items(): self.weights_quantization_params[param_name] = param_value - def calculate_and_set_weights_params(self, tensor_data: np.ndarray) -> float: + def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float): """ Args: tensor_data: Tensor content as Numpy array. + min_threshold: A minimal threshold to set as quantization parameter. Returns: Recalculated weights quantization params from the kernel and channel axis. @@ -336,10 +332,10 @@ def calculate_and_set_weights_params(self, tensor_data: np.ndarray) -> float: p=self.l_p_value, n_bits=self.weights_n_bits, per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None, - channel_axis=self.weights_channels_axis, - min_threshold=self.min_threshold)) + channel_axis=self.weights_channels_axis[0], # output channel axis + min_threshold=min_threshold)) else: - return self.set_weights_quantization_param({}) + self.set_weights_quantization_param({}) def has_weights_quantization_params(self) -> bool: """ @@ -359,7 +355,7 @@ def __eq__(self, other: Any) -> bool: Returns: Whether the objects are identical or not. 
""" - if not isinstance(other, NodeWeightsQuantizationConfig): + if not isinstance(other, WeightsAttrQuantizationConfig): return False return self.weights_quantization_fn == other.weights_quantization_fn and \ @@ -368,13 +364,9 @@ def __eq__(self, other: Any) -> bool: self.weights_error_method == other.weights_error_method and \ self.weights_quantization_method == other.weights_quantization_method and \ self.weights_n_bits == other.weights_n_bits and \ - self.weights_bias_correction == other.weights_bias_correction and \ - self.weights_second_moment_correction == other.weights_second_moment_correction and \ self.weights_per_channel_threshold == other.weights_per_channel_threshold and \ self.enable_weights_quantization == other.enable_weights_quantization and \ - self.min_threshold == other.min_threshold and \ - self.l_p_value == other.l_p_value and \ - self.simd_size == other.simd_size + self.l_p_value == other.l_p_value def __hash__(self): return hash((self.weights_quantization_fn, @@ -383,10 +375,207 @@ def __hash__(self): self.weights_error_method, self.weights_quantization_method, self.weights_n_bits, - self.weights_bias_correction, - self.weights_second_moment_correction, self.weights_per_channel_threshold, self.enable_weights_quantization, - self.min_threshold, - self.l_p_value, - self.simd_size)) + self.l_p_value)) + + +class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig): + """ + Holding a mapping between the node's weights attributes and their quantization configurations, + in addition to quantization parameters that are global for all attributes of the represented node. + """ + def __init__(self, qc: QuantizationConfig, + op_cfg: OpQuantizationConfig, + weights_channels_axis: Tuple[int, int], + node_attrs_list: List[str]): + """ + + Args: + qc: QuantizationConfig to create the node's config from. + op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration. + weights_channels_axis: Axis to quantize a node's weights attribute when quantizing per-channel. + node_attrs_list: A list of the node's weights attributes names. + + """ + self.min_threshold = qc.min_threshold + self.simd_size = op_cfg.simd_size + self.weights_second_moment_correction = qc.weights_second_moment_correction + self.weights_bias_correction = qc.weights_bias_correction + + # Initialize a quantization configuration for each of the node's attributes + self.attributes_config_mapping = {} + self.pos_attributes_config_mapping = {} + for attr in node_attrs_list: + if isinstance(attr, int): + # this is a positional attribute, so it needs to be handled separately. + # we assume that a positional attribute is quantized with the default configuration provided in the TPC. + if op_cfg.default_weight_attr_config.enable_weights_quantization: + Logger.critical(f"Quantizing constant weights is not supported.") + self.pos_attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc, + weights_attr_cfg=op_cfg.default_weight_attr_config, + weights_channels_axis=weights_channels_axis) + else: + # In Tensorflow, the attribute name is composed of the framework attribute name and the layer name, + # therefore, we need to look for the attribute in the op_cfg that is contained in the node attribute's name. + attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if k in attr} + if len(attrs_included_in_name) > 1: + Logger.error(f"Found multiple attribute in TPC OpConfig that are contained " + f"in the attribute name '{attr}'." 
+ f"Please fix the TPC attribute names mapping such that each operator's attribute would " + f"have a unique matching name.") + if len(attrs_included_in_name) == 0: + attr_cfg = op_cfg.default_weight_attr_config + else: + attr_cfg = list(attrs_included_in_name.values())[0] + + self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(qc=qc, + weights_attr_cfg=attr_cfg, + weights_channels_axis=weights_channels_axis) + + def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantizationConfig: + """ + Returns a weights attribute config for an attribute that contains the given name. + If multiple attributes that contain the given name are found - looking for the exact name, otherwise, + fails with an error message. + If none attributes that contain the given name are found - fails with an error message. + + Args: + attr_name: The name of the attribute to get its quantization configuration. + + Returns: An attribute quantization configuration. + + """ + if attr_name is None: + Logger.error("Got 'None' attribute name for retrieving weights attribute quantization configuration.") + + if isinstance(attr_name, int): + # this is a positional attribute + attr_cfg = self.pos_attributes_config_mapping.get(attr_name) + else: + attrs_with_name = self._extract_config_for_attributes_with_name(attr_name) + attr_cfg = None + if len(attrs_with_name) == 1: + attr_cfg = [v for v in attrs_with_name.values()][0] + elif len(attrs_with_name) > 1: + Logger.warning(f"Found multiple weight attributes containing the name {attr_name}: " + f"{list(attrs_with_name.keys())}. Looking for an attributes with the exact name.") + # If no attribute with the exact name then an error would be thrown + attr_cfg = self.attributes_config_mapping.get(attr_name) + + if attr_cfg is None: + Logger.error(f"Weight attribute '{attr_name}' config could not be found.") + + return attr_cfg + + def set_attr_config(self, attr_name: Union[str, int], attr_qc: WeightsAttrQuantizationConfig): + """ + Adding a new attribute with quantization configuration to the node's weights configurations mapping. + + Args: + attr_name: The name of the attribute to set a quantization configuration to. + attr_qc: The quantization configuration to set. + + """ + if isinstance(attr_name, int): + self.pos_attributes_config_mapping[attr_name] = attr_qc + else: + self.attributes_config_mapping[attr_name] = attr_qc + + def has_attribute_config(self, attr_name: Union[str, int]) -> bool: + """ + Checks whether the node weights configuration contains a configuration for a given weights attribute. + + Args: + attr_name: The attribute name to check if a quantization configuration is defined for. + + Returns: True if the attribute exists in the attributes configuration mapping, False otherwise. + + """ + if isinstance(attr_name, int): + return self.pos_attributes_config_mapping.get(attr_name, False) + else: + saved_attr_name = self._extract_config_for_attributes_with_name(attr_name) + if len(saved_attr_name) >= 1: + return True + + return False + + def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]: + """ + Extract the saved attributes that contain the given attribute name. + Relevant to Tensorflow where attributes are presented with the layer name and index, + in addition to the attribute actual name. + + Args: + attr_name: An attribute to extract its saved name. + + Returns: A mapping between attributes that contain the given name to their configuration. 
+ + """ + attrs_with_name = {k: v for k, v in self.attributes_config_mapping.items() if attr_name in k} + if len(attrs_with_name) > 1: + Logger.warning(f"Found multiple weight attributes containing the name {attr_name}: " + f"{list(attrs_with_name.keys())}.") + return attrs_with_name + + def set_quant_config_attr(self, parameter_name: str, parameter_value: Any, attr_name: str = None, + *args: List[Any], **kwargs: Dict[str, Any]): + """ + This method overrides the parent class set_quant_config_attr to enable setting a specific weights + attribute config parameter. + + Args: + attr_name: attribute name to change. + parameter_name: parameter name to change. + parameter_value: parameter value to change. + args: A list of additional arguments. + kwargs: A dictionary with additional key arguments. + + """ + + if attr_name is None: + super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(parameter_name, parameter_value, + *args, **kwargs) + else: + if self.has_attribute_config(attr_name): + attr_cfg = self.get_attr_config(attr_name) + if hasattr(attr_cfg, parameter_name): + setattr(attr_cfg, parameter_name, parameter_value) + else: + Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config of " + f"weights attribute {attr_name} and was not updated!") + else: + Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {parameter_name}.") + + def __eq__(self, other: Any) -> bool: + """ + Compares the object to another object to find if they are equal. + + Args: + other: An object to compare to. + + Returns: Whether the objects are identical or not. + + """ + if not isinstance(other, NodeWeightsQuantizationConfig): + return False + + return self.min_threshold == other.min_threshold and \ + self.simd_size == other.simd_size and \ + self.weights_second_moment_correction == other.weights_second_moment_correction and \ + self.weights_bias_correction == other.weights_bias_correction and \ + self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \ + all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k] + for k in self.attributes_config_mapping.keys()]) and \ + self.pos_attributes_config_mapping.keys() == other.pos_attributes_config_mapping.keys() and \ + all([self.pos_attributes_config_mapping[k] == other.pos_attributes_config_mapping[k] + for k in self.pos_attributes_config_mapping.keys()]) + + def __hash__(self): + return hash((self.min_threshold, + self.simd_size, + self.weights_second_moment_correction, + self.weights_bias_correction, + frozenset(self.attributes_config_mapping), + frozenset(self.pos_attributes_config_mapping))) diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py index 8f4eb2cbc..a0fe6458d 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py @@ -15,21 +15,17 @@ from tqdm import tqdm from typing import List -from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation -from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common import Graph, BaseNode from 
model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \ import get_activations_qparams from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \ - get_weights_qparams, get_channels_axis + get_weights_qparams from model_compression_toolkit.logger import Logger def calculate_quantization_params(graph: Graph, - fw_info: FrameworkInfo, nodes: List[BaseNode] = [], - specific_nodes: bool = False, - fw_impl: FrameworkImplementation = None): + specific_nodes: bool = False): """ For a graph, go over its nodes, compute quantization params (for both weights and activations according to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the @@ -39,12 +35,10 @@ def calculate_quantization_params(graph: Graph, a list of nodes should be passed as well. Args: - fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.) graph: Graph to compute its nodes' thresholds. nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph. specific_nodes: Flag to compute thresholds for only specific nodes. - fw_impl: FrameworkImplementation with specific framework implementations. """ @@ -57,14 +51,21 @@ def calculate_quantization_params(graph: Graph, for n in tqdm(nodes_list, "Calculating quantization params"): # iterate only nodes that we should compute their thresholds for candidate_qc in n.candidates_quantization_cfg: - if n.is_weights_quantization_enabled(): - # If node's weights should be quantized, we compute its weights' quantization parameters - output_channels_axis, _ = get_channels_axis(candidate_qc.weights_quantization_cfg, fw_info, n.type) - weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL), - candidate_qc.weights_quantization_cfg, - output_channels_axis) - candidate_qc.weights_quantization_cfg.set_weights_quantization_param(weights_params) - candidate_qc.weights_quantization_cfg.weights_channels_axis = output_channels_axis + for attr in n.get_node_weights_attributes(): + if n.is_weights_quantization_enabled(attr): + # If the node's weights attribute should be quantized, we compute its quantization parameters + attr_cfg = candidate_qc.weights_quantization_cfg.get_attr_config(attr) + channels_axis = attr_cfg.weights_channels_axis + if channels_axis is not None: + output_channels_axis = channels_axis[0] + else: + output_channels_axis = None + weights_params = get_weights_qparams(n.get_weights_by_keys(attr), + candidate_qc.weights_quantization_cfg, + attr_cfg, + output_channels_axis) + attr_cfg.set_weights_quantization_param(weights_params) + if n.is_activation_quantization_enabled(): # If node's activations should be quantized as well, we compute its activation quantization parameters activation_params = get_activations_qparams( diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py index 370c9ecbc..a1e14a7c2 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py @@ -19,8 +19,8 @@ from 
model_compression_toolkit.logger import Logger from model_compression_toolkit.defaultdict import DefaultDict from model_compression_toolkit.core.common.framework_info import FrameworkInfo -from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig - +from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \ + WeightsAttrQuantizationConfig # If the quantization config does not contain kernel channel mapping or the weights # quantization is not per-channel, we use a dummy channel mapping. @@ -29,6 +29,7 @@ def get_weights_qparams(kernel: np.ndarray, weights_quant_config: NodeWeightsQuantizationConfig, + attr_quant_config: WeightsAttrQuantizationConfig, output_channels_axis: int) -> Dict[Any, Any]: """ Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig @@ -37,27 +38,26 @@ def get_weights_qparams(kernel: np.ndarray, Args: kernel: Kernel to compute the quantization thresholds to. weights_quant_config: Weights quantization configuration to define how the thresholds are computed. + attr_quant_config: A specific weights attribute quantization configuration to get its params. output_channels_axis: Index of the kernel output channels dimension. Returns: A dictionary with the quantization threshold of the kernel. """ - if weights_quant_config.weights_quantization_params_fn is not None: - weights_params = weights_quant_config.weights_quantization_params_fn(kernel, - p=weights_quant_config.l_p_value, - n_bits=weights_quant_config.weights_n_bits, - per_channel=weights_quant_config.weights_per_channel_threshold and output_channels_axis is not None, - channel_axis=output_channels_axis, - min_threshold=weights_quant_config.min_threshold, - quant_error_method=weights_quant_config.weights_error_method) + if attr_quant_config.weights_quantization_params_fn is not None: + weights_params = attr_quant_config.weights_quantization_params_fn(kernel, + p=attr_quant_config.l_p_value, + n_bits=attr_quant_config.weights_n_bits, + per_channel=attr_quant_config.weights_per_channel_threshold and output_channels_axis is not None, + channel_axis=output_channels_axis, + min_threshold=weights_quant_config.min_threshold, + quant_error_method=attr_quant_config.weights_error_method) else: weights_params = {} return weights_params - - def _get_kernel_channels_mapping(fw_info:FrameworkInfo, use_dummy: bool) -> DefaultDict: """ @@ -78,33 +78,3 @@ def _get_kernel_channels_mapping(fw_info:FrameworkInfo, else: kernel_channels_mapping = fw_info.kernel_channels_mapping return kernel_channels_mapping - - - - -def get_channels_axis(weights_quant_config: NodeWeightsQuantizationConfig, - fw_info: FrameworkInfo, - node_type: type) -> Tuple[Any, Any]: - """ - Get the layer's kernel channels input/output indices. - - Args: - weights_quant_config: NodeWeightsQuantizationConfig object of the node we would like get - channels axis for. This is needed for whether to use dummy mapping or not. - fw_info: Framework info contains the kernel channels mapping. - node_type: Class to get its kernel's channels indices. - - Returns: - Class's kernel input/output channels indices. - """ - # If weights should be quantized per-channel but a kernel channels mapping is missing. 
- if weights_quant_config.weights_per_channel_threshold and \ - fw_info.kernel_channels_mapping is None: - Logger.warning('Weights Per Channel Quantization requires channel mapping function,' - ' but framework info does not contain one') - use_dummy = not weights_quant_config.weights_per_channel_threshold and not \ - weights_quant_config.weights_bias_correction - kernel_channels_mapping = _get_kernel_channels_mapping(fw_info, use_dummy) - output_channels_axis, input_channels_axis = kernel_channels_mapping.get(node_type) - return output_channels_axis, input_channels_axis - diff --git a/model_compression_toolkit/core/common/quantization/quantize_graph_weights.py b/model_compression_toolkit/core/common/quantization/quantize_graph_weights.py index a3967f728..7e89b32a4 100644 --- a/model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +++ b/model_compression_toolkit/core/common/quantization/quantize_graph_weights.py @@ -19,13 +19,11 @@ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common.graph.base_graph import Graph -from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc +from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc from model_compression_toolkit.logger import Logger -def quantize_graph_weights(graph: Graph, - fw_info: FrameworkInfo, - fw_impl: FrameworkImplementation) -> Graph: +def quantize_graph_weights(graph: Graph) -> Graph: """ Get a graph representing a model, and quantize its nodes' weights. Each node is quantized according to the passed framework info and quantization configuration. @@ -34,25 +32,23 @@ def quantize_graph_weights(graph: Graph, Args: graph: Graph to quantize its nodes. - fw_info: Framework information needed for quantizing the graph's nodes' weights and activations. - fw_impl: FrameworkImplementation with specific framework implementations. """ # Iterate over nodes in the graph and quantize each node's weights and activations # (according to operators groups in framework info). for n in graph.nodes(): - - if n.is_weights_quantization_enabled(): - quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info, - n, - n.final_weights_quantization_cfg, - fw_impl=fw_impl) - - Logger.debug( - f'Node name: {n.name} has the following quantization params: ' - f'{str(n.final_weights_quantization_cfg.weights_quantization_params)}') - - # Set the kernel node to be the quantized kernel. - n.set_weights_by_keys(fw_impl.constants.KERNEL, quantized_kernel) + for attr in n.get_node_weights_attributes(): + if n.is_weights_quantization_enabled(attr): + quantized_attr, io_channels_axes = \ + get_quantized_weights_attr_by_qc(attr, + n, + n.final_weights_quantization_cfg.get_attr_config(attr)) + + Logger.debug( + f'Weights attribute: {attr} of node name: {n.name} has the following quantization params: ' + f'{str(n.final_weights_quantization_cfg.get_attr_config(attr).weights_quantization_params)}') + + # Set the attribute to be the quantized attribute. 
+ n.set_weights_by_keys(attr, quantized_attr)
 return graph
diff --git a/model_compression_toolkit/core/common/quantization/quantize_node.py b/model_compression_toolkit/core/common/quantization/quantize_node.py
index 664ec7eec..3cdb8d24f 100644
--- a/model_compression_toolkit/core/common/quantization/quantize_node.py
+++ b/model_compression_toolkit/core/common/quantization/quantize_node.py
@@ -14,51 +14,44 @@
 # ==============================================================================
-from model_compression_toolkit.core import common
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
-from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
-from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation \
- import \
- get_channels_axis
+from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
-def get_quantized_kernel_by_weights_qc(fw_info: FrameworkInfo,
- n: BaseNode,
- weights_qc: NodeWeightsQuantizationConfig,
- fw_impl: FrameworkImplementation):
+def get_quantized_weights_attr_by_qc(attr_name: str,
+ n: BaseNode,
+ weights_qc: WeightsAttrQuantizationConfig):
 """
- For a node and weights quantization configuration, compute
- the quantized kernel of the node and return it and the input/output channels indices.
+ For a weights attribute and weights attribute quantization configuration, compute
+ the quantized weights of the node's attribute and return it
+ and the input/output channels indices (if relevant, otherwise None).
 Args:
- fw_info: A FrameworkInfo object Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
- n: Node to quantize its kernel.
- weights_qc: Weight quantization configuration to use for the quantization.
- fw_impl: FrameworkImplementation with specific framework implementations.
+ attr_name: The name of the attribute to quantize.
+ n: Node to quantize its weights attribute.
+ weights_qc: Weight attribute quantization configuration to use for the quantization.
 Returns: A quantized kernel of the node using a weights quantization configuration.
 """
- # If weights should be quantized per-channel but a kernel channels mapping is missing.
- if weights_qc.weights_per_channel_threshold and fw_info.kernel_channels_mapping is \
- None:
- Logger.warning(
- 'Weights Per Channel Quantization requires channel mapping function but framework info '
- 'does not contain one')
- output_channels_axis, input_channels_axis = get_channels_axis(weights_qc,
- fw_info,
- n.type)
-
- Logger.debug(f'quantizing {n.name} with {weights_qc.weights_n_bits} bits')
- quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(fw_impl.constants.KERNEL),
+ channels_axis = weights_qc.weights_channels_axis
+ if channels_axis is not None:
+ # switching output and input channel axis order in the tuple because this is what
+ # the caller of this function expects.
The new order is: (input, output) + channels_axis = (channels_axis[1], channels_axis[0]) + output_channels_axis = channels_axis[1] + else: + channels_axis = None + output_channels_axis = None + + Logger.debug(f'quantizing layer {n.name} attribute {attr_name} with {weights_qc.weights_n_bits} bits') + quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(attr_name), n_bits=weights_qc.weights_n_bits, signed=True, quantization_params=weights_qc.weights_quantization_params, per_channel=weights_qc.weights_per_channel_threshold, output_channels_axis=output_channels_axis) - return quantized_kernel, (input_channels_axis, output_channels_axis) \ No newline at end of file + return quantized_kernel, channels_axis diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index b3545636b..41e141f76 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -15,7 +15,7 @@ import copy -from typing import List +from typing import List, Tuple from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.logger import Logger @@ -76,17 +76,20 @@ def set_quantization_configs_to_node(node: BaseNode, node_qc_options = node.get_qco(tpc) # Create QC candidates for weights and activation combined - weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)[0] + weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type) node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config, fw_info, weight_channel_axis, node_qc_options, - node.type, + node, mixed_precision_enable=mixed_precision_enable) + # sorting the candidates by kernel attribute weights number of bits first and then by activation number of bits + # (in reversed order). since only kernel attribute is quantized in weights mixed precision, + # if the node doesn't have a kernel attribute, we only sort by activation_n_bits. + node.sort_node_candidates(fw_info) + for candidate_qc in node.candidates_quantization_cfg: - candidate_qc.weights_quantization_cfg.enable_weights_quantization = \ - candidate_qc.weights_quantization_cfg.enable_weights_quantization and node.has_weights_to_quantize(fw_info) candidate_qc.activation_quantization_cfg.enable_activation_quantization = \ candidate_qc.activation_quantization_cfg.enable_activation_quantization and node.get_has_activation() @@ -121,9 +124,9 @@ def create_node_activation_qc(qc: QuantizationConfig, def _create_node_single_candidate_qc(qc: QuantizationConfig, fw_info: FrameworkInfo, - weight_channel_axis: int, + weight_channel_axis: Tuple[int, int], op_cfg: OpQuantizationConfig, - kernel_attr: str) -> CandidateNodeQuantizationConfig: + node_attrs_list: List[str]) -> CandidateNodeQuantizationConfig: """ Create quantization configuration candidate from a QuantizationConfig object. Creates both weights and activation quantization configurations @@ -133,51 +136,45 @@ def _create_node_single_candidate_qc(qc: QuantizationConfig, qc: QuantizationConfig to create the node's config from. fw_info: Information about the specific framework the node was created from (e.g., whether its weights/activations should be quantized) - weight_channel_axis: Output channel index of the node's kernel. + weight_channel_axis: (Output, Input) channel index of the node's kernel. 
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
- kernel_attr: The name of the kernel attribute of the node,
- TODO: kernel_attr should be removed once enabling attributes quantization (because this function would create
- candidate for all attributes not specifically for the kernel
+ node_attrs_list: A list of the node's weights attributes names.
 Returns: a CandidateNodeQuantizationConfig object with both weights and activation quantization config objects.
 """
- # get attributes for weights quantization.
- # if the node doesn't have a specified kernel config we use the default attribute config for quantization.
- # TODO: This should be the behavior for all attributes that are not specified in the attribute config mapping,
- # which currently disables the quantization of the weights attribute.
- weights_cfg = op_cfg.attr_weights_configs_mapping.get(kernel_attr, op_cfg.default_weight_attr_config)
-
- weights_quantization_fn = get_weights_quantization_fn(weights_cfg.weights_quantization_method)
-
- if weights_quantization_fn is None:
- Logger.critical(f'Unknown quantization method for weights for quantizing attribute: {kernel_attr}') # pragma: no cover
-
- weights_quantization_params_fn = get_weights_quantization_params_fn(weights_cfg.weights_quantization_method)
+ # parameters for weights attributes quantization are set within CandidateNodeQuantizationConfig initialization
- # get attributes for activation quantization
+ # get parameters for activation quantization
 activation_quantization_fn = fw_info.activation_quantizer_mapping.get(op_cfg.activation_quantization_method)
 if activation_quantization_fn is None:
 Logger.critical('Unknown quantization method for activations') # pragma: no cover
 activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
+ # TODO: remove this validation and warning once enabling all attributes quantization by default
+ attrs_with_enabled_quantization = [attr for attr, cfg in op_cfg.attr_weights_configs_mapping.items()
+ if cfg.enable_weights_quantization]
+ if len(attrs_with_enabled_quantization) > 1:
+ Logger.warning(f"Multiple weights attributes quantization is enabled via the provided TPC. "
+ f"Quantizing any attribute other than the kernel is experimental "
+ f"and may be subject to unstable behavior. "
+ f"Attributes with enabled weights quantization: {attrs_with_enabled_quantization}.")
+
 return CandidateNodeQuantizationConfig(qc=qc,
 op_cfg=op_cfg,
 activation_quantization_fn=activation_quantization_fn,
 activation_quantization_params_fn=activation_quantization_params_fn,
- weights_quantization_fn=weights_quantization_fn,
- weights_quantization_params_fn=weights_quantization_params_fn,
 weights_channels_axis=weight_channel_axis,
- weights_cfg=weights_cfg)
+ node_attrs_list=node_attrs_list)
def _create_node_candidates_qc(qc: QuantizationConfig,
 fw_info: FrameworkInfo,
- weight_channel_axis: int,
+ weight_channel_axis: Tuple[int, int],
 node_qc_options: QuantizationConfigOptions,
- node_type: type,
+ node: BaseNode,
 mixed_precision_enable: bool = False) -> List[CandidateNodeQuantizationConfig]:
 """
 Create a list of candidates of weights and activation quantization configurations for a node.
@@ -185,9 +182,9 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
 Args:
 qc: Quantization configuration the quantization process should follow.
 fw_info: Framework information (e.g., which layers should have their kernels' quantized).
- weight_channel_axis: Output channel index of the node's kernel. + weight_channel_axis: (Output, Input) channel index of the node's kernel. node_qc_options: QuantizationConfigOptions for the node with quantization candidates information. - node_type: The type of the layer that the node represents. + node: A node to set quantization configuration candidates to. mixed_precision_enable: is mixed precision enabled Returns: @@ -195,31 +192,22 @@ def _create_node_candidates_qc(qc: QuantizationConfig, """ candidates = [] - - # TODO: Currently, we are using fw_info to get the kernel attribute, but this would changed once we enable multi - # attribute quantization via AttributeQuantizationConfig class (needs to be implemented) - - kernel_attr = fw_info.get_kernel_op_attributes(node_type) - assert len(kernel_attr) == 1 - kernel_attr = kernel_attr[0] + node_attrs_list = node.get_node_weights_attributes() if mixed_precision_enable: for op_cfg in node_qc_options.quantization_config_list: - candidate_nbits_qc = copy.deepcopy(qc) - candidates.append(_create_node_single_candidate_qc(candidate_nbits_qc, + candidate_qc = copy.deepcopy(qc) + candidates.append(_create_node_single_candidate_qc(candidate_qc, fw_info, weight_channel_axis, op_cfg, - kernel_attr)) - # sorting the candidates by weights number of bits first and then by activation number of bits - # (in reversed order) - candidates.sort(key=lambda c: (c.weights_quantization_cfg.weights_n_bits, - c.activation_quantization_cfg.activation_n_bits), reverse=True) + node_attrs_list)) + else: candidates.append(_create_node_single_candidate_qc(qc, fw_info, weight_channel_axis, node_qc_options.base_config, - kernel_attr)) + node_attrs_list)) return candidates diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py index 9b183ced4..3a37a97d9 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py @@ -14,9 +14,12 @@ # ============================================================================== import copy +from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig +from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph, @@ -37,17 +40,21 @@ def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph, graph = copy.deepcopy(graph_to_apply_bias_correction) for n in graph.nodes: - if n.is_weights_quantization_enabled() and core_config.quantization_config.weights_bias_correction \ - and not n.final_weights_quantization_cfg.weights_second_moment_correction: + # bias correction is only relevant for nodes with kernel op + kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0] + if core_config.quantization_config.weights_bias_correction and kernel_attr is not None and \ + n.is_weights_quantization_enabled(kernel_attr) and \ + not 
n.final_weights_quantization_cfg.weights_second_moment_correction: # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg, # a bias correction term was calculated during model preparation, and is used now in the node's bias term. if n.final_weights_quantization_cfg.weights_bias_correction: - _apply_bias_correction_to_node(n, fw_impl) + _apply_bias_correction_to_node(n, fw_impl, core_config.quantization_config) return graph -def _apply_bias_correction_to_node(node:BaseNode, - fw_impl: FrameworkImplementation): +def _apply_bias_correction_to_node(node: BaseNode, + fw_impl: FrameworkImplementation, + qc: QuantizationConfig): """ Set new bias to node using the bias-correction term that is stored in the final weights quantization configuration. @@ -55,6 +62,7 @@ def _apply_bias_correction_to_node(node:BaseNode, Args: node: Node to set its corrected bias after bias-correction. fw_impl: FrameworkImplementation object with a specific framework methods implementation. + qc: QuantizationConfig containing parameters of how the model should be quantized. """ correction = node.final_weights_quantization_cfg.bias_corrected @@ -64,6 +72,13 @@ def _apply_bias_correction_to_node(node:BaseNode, if bias is not None: # If the layer has bias, we subtract the correction from original bias node.set_weights_by_keys(fw_impl.constants.BIAS, node.get_weights_by_keys(fw_impl.constants.BIAS) - correction) - else: # It the layer has no bias, we consider it as if it has and its value is 0. + else: + # If the layer has no bias, we consider it as if it has and its value is 0 and add a "dummy" attribute + # configuration with disabled quantization. node.set_weights_by_keys(fw_impl.constants.BIAS, - correction) node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node. + node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, + WeightsAttrQuantizationConfig( + qc, + AttributeQuantizationConfig( + enable_weights_quantization=False))) diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py index 041189da5..6d8f0e189 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py @@ -76,9 +76,7 @@ def quantized_model_builder_for_second_moment_correction(graph: common.Graph, Returns: Quantized model for second moment correction. """ - quantized_tg = quantize_graph_weights(graph, - fw_info=fw_info, - fw_impl=fw_impl) + quantized_tg = quantize_graph_weights(graph) quantized_model, user_info = fw_impl.model_builder(quantized_tg, mode=ModelBuilderMode.FLOAT, diff --git a/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py b/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py index 7e4a56ca9..6f488104b 100644 --- a/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py @@ -13,22 +13,19 @@ # limitations under the License. 
# ============================================================================== -import copy from typing import Any import numpy as np -from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.core.common import BaseNode, Graph -from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_kernel_by_weights_qc +from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector from model_compression_toolkit.logger import Logger def compute_bias_correction_of_graph(graph: Graph, - core_config: CoreConfig, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation) -> Graph: """ @@ -38,7 +35,6 @@ def compute_bias_correction_of_graph(graph: Graph, Args: graph: Graph with nodes to compute the bias correction for each node's weights quantization configuration candidates. - core_config: CoreConfig containing parameters of how the model should be quantized. fw_info: Framework info like lists of nodes their kernel should quantized. fw_impl: FrameworkImplementation object with a specific framework methods implementation. @@ -48,15 +44,21 @@ def compute_bias_correction_of_graph(graph: Graph, """ for n in graph.nodes: - if n.is_weights_quantization_enabled() and core_config.quantization_config.weights_bias_correction: - _compute_bias_correction_per_candidate_qc(n, - fw_info, - graph.get_in_stats_collector(n), - fw_impl=fw_impl) + # Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute + # name out of all the weights attributes of the node. + if fw_info.is_kernel_op(n.type): + kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0] + if n.is_weights_quantization_enabled(kernel_attr): + _compute_bias_correction_per_candidate_qc(n, + kernel_attr, + fw_info, + graph.get_in_stats_collector(n), + fw_impl=fw_impl) return graph def _compute_bias_correction_per_candidate_qc(node: BaseNode, + kernel_attr: str, fw_info: FrameworkInfo, node_in_stats_collector: BaseStatsCollector, fw_impl: FrameworkImplementation): @@ -66,6 +68,7 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode, Args: node: Node to compute the bias correction for its different candidates. + kernel_attr: The name of the kernel attribute of the node. fw_info: Framework info like lists of nodes their kernel should quantized. node_in_stats_collector: Statistics collector of the node for the mean per-channel. fw_impl: FrameworkImplementation object with a specific framework methods implementation. @@ -73,25 +76,24 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode, """ for candidate_qc in node.candidates_quantization_cfg: - if candidate_qc.weights_quantization_cfg.enable_weights_quantization and not \ + if candidate_qc.weights_quantization_cfg.weights_bias_correction and not \ candidate_qc.weights_quantization_cfg.weights_second_moment_correction: - quantized_kernel, io_channels_axes = get_quantized_kernel_by_weights_qc(fw_info, - node, - candidate_qc.weights_quantization_cfg, - fw_impl=fw_impl) - - # If a kernel was quantized and weights bias correction is enabled in n.quantization_cfg, - # a bias correction term is being calculated and used in the node's bias term. 
- if candidate_qc.weights_quantization_cfg.weights_bias_correction: - bias_correction_term = _get_bias_correction_term_of_node(io_channels_axes[0], - node, - node_in_stats_collector, - io_channels_axes[1], - quantized_kernel, - fw_impl=fw_impl) - - # Store the correction term to use it later, - candidate_qc.weights_quantization_cfg.bias_corrected = bias_correction_term + + quantized_kernel, io_channels_axes = get_quantized_weights_attr_by_qc(kernel_attr, + node, + candidate_qc.weights_quantization_cfg + .get_attr_config(kernel_attr)) + + bias_correction_term = _get_bias_correction_term_of_node(io_channels_axes[0], + node, + node_in_stats_collector, + io_channels_axes[1], + quantized_kernel, + fw_impl=fw_impl) + + # Store the correction term to use it later, + candidate_qc.weights_quantization_cfg.bias_corrected = bias_correction_term + def is_non_positive_integer(x: float) -> bool: """ @@ -103,6 +105,7 @@ def is_non_positive_integer(x: float) -> bool: """ return x < 1 or int(x) != x + def _compute_bias_correction(kernel: np.ndarray, quantized_kernel: np.ndarray, in_statistics_container: BaseStatsCollector, diff --git a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py index 6d7d23129..e11aa0290 100644 --- a/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +++ b/model_compression_toolkit/core/common/statistics_correction/statistics_correction.py @@ -57,7 +57,6 @@ def statistics_correction_runner(transformed_graph: Graph, # Compute bias correction to nodes' config candidates ######################################################## tg_with_bias = compute_bias_correction_of_graph(tg_with_bias, - core_config, fw_info, fw_impl) diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py index 861bb5d3a..482fa4fce 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py @@ -19,12 +19,15 @@ import numpy as np +from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig from model_compression_toolkit.core import common +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.graph.base_node import BaseNode from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher -from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod +from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \ + AttributeQuantizationConfig class BatchNormalizationReconstruction(common.BaseSubstitution): @@ -92,7 +95,12 @@ def substitute(self, # This feature disabled for models with weights quantization method of Power of 2 for qc in source_node.candidates_quantization_cfg: - if qc.weights_quantization_cfg.weights_quantization_method == QuantizationMethod.POWER_OF_TWO: + # this feature is relevant only for layers with kernel op + kernel_attr = graph.fw_info.get_kernel_op_attributes(source_node.type) + if kernel_attr is None: + Logger.error(f"Can't 
perform BatchNorm reconstruction on a node {source_node.name} without a kernel op.")
+ if (qc.weights_quantization_cfg.get_attr_config(kernel_attr[0]).weights_quantization_method
+ == QuantizationMethod.POWER_OF_TWO):
 Logger.warning("Second moment statistics correction feature disabled for models with weights "
 "quantization method of Power of 2")
 for qc_inner in source_node.candidates_quantization_cfg:
@@ -119,8 +127,21 @@ def substitute(self,
 bn_node.candidates_quantization_cfg = copy.deepcopy(source_node.candidates_quantization_cfg)
 for qc in bn_node.candidates_quantization_cfg:
- qc.weights_quantization_cfg.enable_weights_quantization = False
 qc.activation_quantization_cfg.enable_activation_quantization = False
+ for attr in bn_node.get_node_weights_attributes():
+ if qc.weights_quantization_cfg.has_attribute_config(attr):
+ # we only create a BN layer to collect statistics, so we don't need to quantize anything,
+ # but we do need to add the BN attributes to the reconstructed node.
+ qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
+ else:
+ # setting a "dummy" attribute configuration with disabled quantization.
+ # TODO: once enabling BN attributes quantization, need to figure out if the
+ # reconstructed node BN attributes need to be quantized and how.
+ qc.weights_quantization_cfg.set_attr_config(attr,
+ WeightsAttrQuantizationConfig(
+ QuantizationConfig(),
+ AttributeQuantizationConfig(
+ enable_weights_quantization=False)))
 graph.reconnect_out_edges(current_node=source_node, new_node=bn_node)
 graph.replace_output_node(current_node=source_node, new_node=bn_node)
diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py b/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
index d8abf3765..0f47087c4 100644
--- a/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
+++ b/model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py
@@ -159,7 +159,7 @@ def substitute(self,
 graph.remove_node(bn_node)
 graph.remove_node(source_node)
- self._calc_weights_quantization_params(conv_bn, weights_scale)
+ self._calc_weights_quantization_params(conv_bn, weights_scale, graph.fw_info)
 assert num_nodes_before_substitution - len(graph.nodes) == 1
 assert num_edges_before_substitution - len(graph.edges) == 1
@@ -167,31 +167,36 @@ def substitute(self,
 def _calc_weights_quantization_params(self,
 conv_bn: BaseNode,
- weights_scale: np.ndarray):
+ weights_scale: np.ndarray,
+ fw_info):
 """
 Update node weights quantization params.
 Args:
 conv_bn: Convolution node to update the weights quantization params.
 weights_scale: Weight scale factor in which to multiply the conv node's weight.
+ fw_info: FrameworkInfo object with information about the specific framework's model """ + # Conv layer is ensured to have a kernel attribute + kernel_attr = fw_info.get_kernel_op_attributes(conv_bn.type)[0] + conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(kernel_attr) # In case of SYMMETRIC weight quantization method, we update the threshold by weights_scale - if conv_bn.final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC: - original_threshold = conv_bn.final_weights_quantization_cfg.weights_quantization_params[THRESHOLD] - corr_dict = copy.deepcopy(conv_bn.final_weights_quantization_cfg.weights_quantization_params) + if conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC: + original_threshold = conv_bn_kernel_cfg.weights_quantization_params[THRESHOLD] + corr_dict = copy.deepcopy(conv_bn_kernel_cfg.weights_quantization_params) corr_threshold, _ = self.update_kernel_for_bn_refusing_fn(conv_bn, original_threshold, weights_scale) corr_dict[THRESHOLD] = corr_threshold - conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict) + conv_bn_kernel_cfg.set_weights_quantization_param(corr_dict) # In case of UNIFORM weight quantization method, we update the range_min, range_max by weights_scale - elif conv_bn.final_weights_quantization_cfg.weights_quantization_method == QuantizationMethod.UNIFORM: - corr_dict = copy.deepcopy(conv_bn.final_weights_quantization_cfg.weights_quantization_params) - original_range_min = conv_bn.final_weights_quantization_cfg.weights_quantization_params[RANGE_MIN] + elif conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.UNIFORM: + corr_dict = copy.deepcopy(conv_bn_kernel_cfg.weights_quantization_params) + original_range_min = conv_bn_kernel_cfg.weights_quantization_params[RANGE_MIN] corr_range_min, _ = self.update_kernel_for_bn_refusing_fn(conv_bn, original_range_min, weights_scale) - original_range_max = conv_bn.final_weights_quantization_cfg.weights_quantization_params[RANGE_MAX] + original_range_max = conv_bn_kernel_cfg.weights_quantization_params[RANGE_MAX] corr_range_max, _ = self.update_kernel_for_bn_refusing_fn(conv_bn, original_range_max, weights_scale) corr_dict[RANGE_MIN] = corr_range_min corr_dict[RANGE_MAX] = corr_range_max - conv_bn.final_weights_quantization_cfg.set_weights_quantization_param(corr_dict) + conv_bn_kernel_cfg.set_weights_quantization_param(corr_dict) else: Logger.exception("Second moment statistics correction feature disabled for models with weights " diff --git a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py index 0116d9163..ec96ab804 100644 --- a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +++ b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py @@ -16,11 +16,14 @@ import numpy as np from typing import List, Tuple, Any, Callable +from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig +from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS from 
model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher -from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod +from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod, \ + AttributeQuantizationConfig from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \ set_quantization_configs_to_node from model_compression_toolkit.core.common.quantization.core_config import CoreConfig @@ -63,6 +66,12 @@ def op2d_bias_correction(op2d_node: BaseNode, if bias is None: bias = 0.0 op2d_node.framework_attr[bias_flag_str] = True + # Add an attribute quantization configuration to the newly added bias attribute, with disabled quantization + for qc in op2d_node.candidates_quantization_cfg: + qc.weights_quantization_cfg.set_attr_config(bias_flag_str, + WeightsAttrQuantizationConfig(QuantizationConfig(), + AttributeQuantizationConfig( + enable_weights_quantization=False))) # Each node adds a different noise due to the shifting. It depends on the # dimensions of the kernel, thus the correction term is a function of @@ -348,8 +357,9 @@ def shift_negative_function(graph: Graph, mixed_precision_enable=core_config.mixed_precision_enable) for candidate_qc in pad_node.candidates_quantization_cfg: - candidate_qc.weights_quantization_cfg.enable_weights_quantization = False candidate_qc.activation_quantization_cfg.enable_activation_quantization = False + for attr in pad_node.get_node_weights_attributes(): + candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False # Insert a pad node between the add node to the op2d, and create statistics for the pad node insert_node_before_node(graph, @@ -382,7 +392,8 @@ def shift_negative_function(graph: Graph, add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg): - candidate_qc.weights_quantization_cfg.enable_weights_quantization = False + for attr in add_node.get_node_weights_attributes(): + candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config, fw_info, diff --git a/model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py b/model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py index f796450f0..f6d74b505 100644 --- a/model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +++ b/model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py @@ -57,6 +57,7 @@ def substitute(self, # the actual arguments values are irrelevant because they are being overridden or not used v_node = VirtualActivationWeightsNode(act_node, weights_node, + fw_info=graph.fw_info, **weights_node.__dict__) # Update graph diff --git a/model_compression_toolkit/core/common/substitutions/weights_activation_split.py b/model_compression_toolkit/core/common/substitutions/weights_activation_split.py index 4a2bd62bb..3716b8eca 100644 --- a/model_compression_toolkit/core/common/substitutions/weights_activation_split.py +++ b/model_compression_toolkit/core/common/substitutions/weights_activation_split.py @@ -49,21 +49,27 @@ def substitute(self, Returns: Graph after applying the substitution. 
""" - - if not node.is_all_weights_candidates_equal() and not node.is_all_activation_candidates_equal(): + # The decomposition works on linear nodes, that is, nodes with kernel ops + kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0] + if kernel_attr is None: + Logger.error(f"Trying to split node weights and activation, but node " + f"{node.name} doesn't have a kernel attribute.") + if not node.is_all_weights_candidates_equal(kernel_attr) and not node.is_all_activation_candidates_equal(): # Node has both different weights and different activation configuration candidates - weights_bits = [c.weights_quantization_cfg.weights_n_bits for c in node.get_unique_weights_candidates()] + weights_bits = [c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits + for c in node.get_unique_weights_candidates(kernel_attr)] activation_bits = [c.activation_quantization_cfg.activation_n_bits for c in node.get_unique_activation_candidates()] expected_candidates = list(itertools.product(weights_bits, activation_bits)) - all_candidates_bits = [(c.weights_quantization_cfg.weights_n_bits, - c.activation_quantization_cfg.activation_n_bits) for c in node.candidates_quantization_cfg] + all_candidates_bits = [(c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits, + c.activation_quantization_cfg.activation_n_bits) + for c in node.candidates_quantization_cfg] if not set(expected_candidates).issubset(all_candidates_bits): # Node is not composite, therefore, can't be split Logger.critical(f"The graph contains a node {node.name} with non composite candidates." f"In order to run mixed-precision search with BOPS target KPI, " f"all model layers should be composite.") # pragma: no cover - weights_node = VirtualSplitWeightsNode(node) + weights_node = VirtualSplitWeightsNode(node, kernel_attr) activation_node = VirtualSplitActivationNode(node, self.activation_layer_type, self.fw_attr) # Update graph diff --git a/model_compression_toolkit/core/common/visualization/nn_visualizer.py b/model_compression_toolkit/core/common/visualization/nn_visualizer.py index fc4e97f82..7652bd96f 100644 --- a/model_compression_toolkit/core/common/visualization/nn_visualizer.py +++ b/model_compression_toolkit/core/common/visualization/nn_visualizer.py @@ -67,7 +67,7 @@ def __init__(self, """ self.graph_float = graph_float - self.graph_quantized = quantize_graph_weights(graph_float, fw_info=fw_info, fw_impl=fw_impl) + self.graph_quantized = quantize_graph_weights(graph_float) self.fw_impl = fw_impl self.fw_info = fw_info diff --git a/model_compression_toolkit/core/common/visualization/tensorboard_writer.py b/model_compression_toolkit/core/common/visualization/tensorboard_writer.py index 447503dd6..90967c3e1 100644 --- a/model_compression_toolkit/core/common/visualization/tensorboard_writer.py +++ b/model_compression_toolkit/core/common/visualization/tensorboard_writer.py @@ -230,7 +230,7 @@ def __get_node_weights_attr(n: BaseNode) -> Dict[str, Any]: if n.final_weights_quantization_cfg is not None: attr.update(n.final_weights_quantization_cfg.__dict__) elif n.candidates_quantization_cfg is not None: - attr.update(n.get_unified_weights_candidates_dict()) + attr.update(n.get_unified_weights_candidates_dict(self.fw_info)) return attr def __get_node_attr(n: BaseNode) -> Dict[str, Any]: diff --git a/model_compression_toolkit/core/exporter.py b/model_compression_toolkit/core/exporter.py index cf8762321..52950df8b 100644 --- a/model_compression_toolkit/core/exporter.py +++ 
b/model_compression_toolkit/core/exporter.py @@ -45,9 +45,7 @@ def _quantize_model(tg: Graph, Quantized model in the input framework, and information the user may need in order to use the quantized model. """ - quantized_tg = quantize_graph_weights(tg, - fw_info=fw_info, - fw_impl=fw_impl) + quantized_tg = quantize_graph_weights(tg) if tb_w is not None: tb_w.add_graph(quantized_tg, 'after_quantization') ###################################### diff --git a/model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py b/model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py index 87b0603fd..b87dfee37 100644 --- a/model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +++ b/model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py @@ -88,30 +88,36 @@ def mixed_precision_wrapper(self, """ - weights_conf_nodes_names = [n.name for n in self.graph.get_weights_configurable_nodes()] - - if n.is_weights_quantization_enabled(): - kernel_attributes = self.fw_info.get_kernel_op_attributes(n.type) + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): + weights_conf_nodes_names = [node.name for node in self.graph.get_weights_configurable_nodes(self.fw_info)] if n.name in weights_conf_nodes_names: return KerasQuantizationWrapper(layer, - weights_quantizers={attr: ConfigurableWeightsQuantizer( - **self._get_weights_configurable_quantizer_kwargs(n, attr)) - for attr in kernel_attributes}) + weights_quantizers={ + kernel_attr: ConfigurableWeightsQuantizer( + **self._get_weights_configurable_quantizer_kwargs(n, + kernel_attr))}) else: - node_weights_qc = n.get_unique_weights_candidates() + # TODO: Do we want to include other quantized attributes that are not + # the kernel attribute in the mixed precision model? + # Currently, we only consider kernel attribute quantization (whether it is in mixed precision + # or single precision). 
+ node_weights_qc = n.get_unique_weights_candidates(kernel_attr) if not len(node_weights_qc) == 1: Logger.error(f"Expecting node {n.name} to have a unique weights configuration " # pragma: no cover f"but {len(node_weights_qc)} different configurations exist.") quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights, - node_weights_qc[0].weights_quantization_cfg.weights_quantization_method, + node_weights_qc[0].weights_quantization_cfg + .get_attr_config(kernel_attr) + .weights_quantization_method, BaseKerasInferableQuantizer) kwargs = get_inferable_quantizer_kwargs(node_weights_qc[0].weights_quantization_cfg, - QuantizationTarget.Weights) + QuantizationTarget.Weights, + kernel_attr) return KerasQuantizationWrapper(layer, - weights_quantizers={attr: quantier_for_node(**kwargs) - for attr in kernel_attributes}) + weights_quantizers={kernel_attr: quantier_for_node(**kwargs)}) return layer @@ -130,7 +136,7 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None" node_q_cfg_candidates = n.candidates_quantization_cfg # sort by descending bit width so using indices would be easier - node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.weights_n_bits, + node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.get_attr_config(attr).weights_n_bits, x.activation_quantization_cfg.activation_n_bits), reverse=True) float_weights = n.get_weights_by_keys(attr) @@ -144,7 +150,8 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> return {'node_q_cfg': node_q_cfg_candidates, 'float_weights': float_weights, - 'max_candidate_idx': max_candidate_idx + 'max_candidate_idx': max_candidate_idx, + 'kernel_attr': attr, } def mixed_precision_activation_holder(self, n: BaseNode) -> KerasActivationQuantizationHolder: @@ -165,21 +172,26 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> KerasActivationQuant activation_quantizers = [] if n.is_activation_quantization_enabled(): num_of_outputs = len(n.output_shape) if isinstance(n.output_shape, list) else 1 + if n.name in activation_conf_nodes_names: assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None" node_q_cfg_candidates = n.candidates_quantization_cfg - # sort by descending bit width so using indices would be easier - node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.weights_n_bits, - x.activation_quantization_cfg.activation_n_bits), - reverse=True) + + # sorting the candidates by kernel attribute weights number of bits first and then by + # activation number of bits (in reversed order). + # since only kernel attribute is quantized in weights mixed precision, + # if the node doesn't have a kernel attribute, we only sort by activation_n_bits. 
+ n.sort_node_candidates(self.fw_info) max_cfg_candidates = n.find_max_candidates_indices() assert len(max_cfg_candidates) == 1, \ f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates" max_candidate_idx = max_cfg_candidates[0] + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates, - 'max_candidate_idx': max_candidate_idx})] \ + 'max_candidate_idx': max_candidate_idx, + 'kernel_attr': kernel_attr})] \ * num_of_outputs else: node_act_qc = n.get_unique_activation_candidates() @@ -219,7 +231,7 @@ def build_model(self) -> Tuple[Model, UserInformation, # creating a mapping between graph nodes and model's layers for mixed precision configurability conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model.layers) - for n in self.graph.get_configurable_sorted_nodes()} + for n in self.graph.get_configurable_sorted_nodes(self.fw_info)} return model, user_info, conf_node2layers @@ -257,7 +269,7 @@ def _get_activation_quant_layers(n: BaseNode, layers_list: List[Layer]) -> List[ def _find_layers_in_model_by_node(self, n: BaseNode, layers_list: List[Layer]) -> \ List[Union[KerasQuantizationWrapper, KerasActivationQuantizationHolder]]: """ - Retries layers from an MP model that are matching to the given graph node, that is, these are either + Retrieves layers from an MP model that are matching to the given graph node, that is, these are either KerasQuantizationWrapper layers or KerasActivationQuantizationHolder layers that are responsible for the graph configurable model quantization. @@ -268,7 +280,9 @@ def _find_layers_in_model_by_node(self, n: BaseNode, layers_list: List[Layer]) - Returns: A list of layers that responsible for the node's quantization. 
""" - weights_quant = n.is_weights_quantization_enabled() + # Only layers with kernel op are considered weights configurable + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] + weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr) act_quant = n.is_activation_quantization_enabled() if weights_quant and not act_quant: diff --git a/model_compression_toolkit/core/keras/constants.py b/model_compression_toolkit/core/keras/constants.py index 133eb2694..502681366 100644 --- a/model_compression_toolkit/core/keras/constants.py +++ b/model_compression_toolkit/core/keras/constants.py @@ -73,7 +73,7 @@ F_STRIDED_SLICE_END = 'end_mask' # Layers variables names: -KERNEL = 'kernel' +KERNEL: str = 'kernel' DEPTHWISE_KERNEL = 'depthwise_kernel' BIAS = 'bias' GAMMA = 'gamma' diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py index 7a2636a01..637bfe210 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py @@ -21,6 +21,7 @@ from model_compression_toolkit.core.common.substitutions.batchnorm_folding import BatchNormalizationFolding, BatchNormalizationForwardFolding from model_compression_toolkit.core.keras.constants import KERNEL, LINEAR, ACTIVATION, DEPTHWISE_KERNEL, BIAS, GAMMA, BETA, \ MOVING_MEAN, MOVING_VARIANCE, EPSILON, USE_BIAS, LAYER_NAME, GROUPS +from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO def batchnorm_folding_node_matchers() -> [BaseNode, BaseNode]: @@ -76,10 +77,7 @@ def update_kernel_for_bn_folding_fn(conv_node: BaseNode, else: kernel = kernel * weights_scale.reshape((1, 1, 1, -1)) - if conv_node.type == DepthwiseConv2D: - kernel_name = DEPTHWISE_KERNEL - else: - kernel_name = KERNEL + kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0] return kernel, kernel_name @@ -110,10 +108,7 @@ def update_weights_for_bn_forward_folding_fn(conv_node: BaseNode, bias_update = (kernel * bias_factor.reshape((1, 1, -1, 1))).sum(2) kernel = kernel * weights_scale.reshape((1, 1, -1, 1)) - if conv_node.type == DepthwiseConv2D: - kernel_name = DEPTHWISE_KERNEL - else: - kernel_name = KERNEL + kernel_name = DEFAULT_KERAS_INFO.get_kernel_op_attributes(conv_node.type)[0] return kernel, bias + bias_update.flatten(), kernel_name diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py index d8aa0b825..de1085646 100644 --- a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py @@ -96,14 +96,17 @@ def substitute(self, scale_factor = threshold_float / threshold graph.user_info.set_input_scale(1 / scale_factor) - w1_fixed = linear_layer.get_weights_by_keys(KERNEL) * scale_factor - linear_layer.set_weights_by_keys(KERNEL, w1_fixed) + kernel_attr = graph.fw_info.get_kernel_op_attributes(linear_layer.type)[0] + + w1_fixed = linear_layer.get_weights_by_keys(kernel_attr) * scale_factor + linear_layer.set_weights_by_keys(kernel_attr, w1_fixed) graph.scale_stats_collector(input_layer, 1 / scale_factor) # After scaling weights may have different thresholds so 
it needs to be recalculated for nqc in linear_layer.candidates_quantization_cfg: - nqc.weights_quantization_cfg.calculate_and_set_weights_params(w1_fixed) + nqc.weights_quantization_cfg.get_attr_config(kernel_attr).calculate_and_set_weights_params(w1_fixed, + nqc.weights_quantization_cfg.min_threshold) return graph diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index 9340a659b..5e089dfb8 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -19,16 +19,17 @@ import tensorflow as tf from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder from tensorflow.keras.models import Model -from tensorflow.python.layers.base import Layer from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoService from model_compression_toolkit.core.keras.hessian.activation_trace_hessian_calculator_keras import \ ActivationTraceHessianCalculatorKeras -from model_compression_toolkit.core.keras.hessian.trace_hessian_calculator_keras import TraceHessianCalculatorKeras from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras +from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ + get_inferable_quantizers +from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ + get_weights_quantizer_for_node, get_activations_quantizer_for_node from model_compression_toolkit.logger import Logger -from model_compression_toolkit.trainable_infrastructure.keras.quantize_wrapper import KerasTrainableQuantizationWrapper from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import set_layer_to_bitwidth from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence, compute_cs, compute_mse @@ -586,3 +587,39 @@ def sensitivity_eval_inference(self, """ return model(inputs) + + def get_inferable_quantizers(self, node: BaseNode): + """ + Returns sets of Keras compatible weights and activation quantizers for the given node. + + Args: + node: Node to get quantizers for. + + Returns: + weight_quantizers: A dictionary between a weight's name to its quantizer. + activation_quantizers: A list of activations quantization, one for each layer output. + + """ + + def _weight_name(w: str) -> str: + """ + Extracts the weight name from the full TensorFlow variable name. + + For example, returns 'kernel' for 'dense_2/kernel:0'. + + Args: + w: TensorFlow variable name. + + Returns: + Extracted weight name. 
+ """ + + return w.split(':')[0].split('/')[-1] + + attribute_names = [_weight_name(wn) for wn in node.get_node_weights_attributes() + if node.is_weights_quantization_enabled(wn)] + + return get_inferable_quantizers(node, + get_weights_quantizer_for_node, + get_activations_quantizer_for_node, + attribute_names) diff --git a/model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py b/model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py index f8a9dcfc4..c0039fa8d 100644 --- a/model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +++ b/model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py @@ -45,7 +45,8 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer): def __init__(self, node_q_cfg: List[CandidateNodeQuantizationConfig], - max_candidate_idx: int = 0): + max_candidate_idx: int = 0, + kernel_attr: str = None): """ Initializes a configurable quantizer. @@ -53,13 +54,14 @@ def __init__(self, node_q_cfg: Quantization configuration candidates of the node that generated the layer that will use this quantizer. max_candidate_idx: Index of the node's candidate that has the maximal bitwidth (must exist absolute max). + kernel_attr: A kernel attribute name if the node have a kernel attribute (used only for candidates order validation). """ super(ConfigurableActivationQuantizer, self).__init__() self.node_q_cfg = node_q_cfg - verify_candidates_descending_order(self.node_q_cfg) + verify_candidates_descending_order(self.node_q_cfg, kernel_attr) for qc in node_q_cfg: if qc.activation_quantization_cfg.enable_activation_quantization != \ diff --git a/model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py b/model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py index 593f0238d..bb611baad 100644 --- a/model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +++ b/model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py @@ -51,6 +51,7 @@ class ConfigurableWeightsQuantizer(BaseKerasInferableQuantizer): def __init__(self, node_q_cfg: List[CandidateNodeQuantizationConfig], float_weights: tf.Tensor, + kernel_attr: str, max_candidate_idx: int = 0): """ Initializes a configurable quantizer. @@ -59,6 +60,7 @@ def __init__(self, node_q_cfg: Quantization configuration candidates of the node that generated the layer that will use this quantizer. float_weights: Float weights of the layer. + kernel_attr: The kernel attribute name of the node. Only layers with kernel op can be configured. max_candidate_idx: Index of the node's candidate that has the maximal bitwidth (must exist absolute max). 
""" @@ -68,19 +70,22 @@ def __init__(self, self.node_q_cfg = node_q_cfg self.float_weights = float_weights self.max_candidate_idx = max_candidate_idx + self.kernel_attr = kernel_attr - verify_candidates_descending_order(self.node_q_cfg) + verify_candidates_descending_order(self.node_q_cfg, kernel_attr) for qc in self.node_q_cfg: - if qc.weights_quantization_cfg.enable_weights_quantization != \ - self.node_q_cfg[0].weights_quantization_cfg.enable_weights_quantization: - Logger.error("Candidates with different weights enabled properties is currently not supported.") + if qc.weights_quantization_cfg.get_attr_config(self.kernel_attr).enable_weights_quantization != \ + self.node_q_cfg[0].weights_quantization_cfg.get_attr_config(self.kernel_attr).enable_weights_quantization: + Logger.error("Candidates with different kernel attribute quantization enabled " + "properties is currently not supported.") # Initialize quantized weights for each weight that should be quantized. self.quantized_weights = init_quantized_weights(node_q_cfg=self.node_q_cfg, float_weights=self.float_weights, fw_tensor_convert_func=partial(tf.convert_to_tensor, - dtype=tf.float32)) + dtype=tf.float32), + kernel_attr=self.kernel_attr) self.active_quantization_config_index = self.max_candidate_idx diff --git a/model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py index 4bc5b483e..d0690ab13 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py @@ -83,29 +83,37 @@ def mixed_precision_wrapper(self, """ - weights_conf_nodes_names = [n.name for n in self.graph.get_weights_configurable_nodes()] + weights_conf_nodes_names = [n.name for n in self.graph.get_weights_configurable_nodes(self.fw_info)] + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): - if n.is_weights_quantization_enabled(): - kernel_attributes = self.fw_info.get_kernel_op_attributes(n.type) if n.name in weights_conf_nodes_names: return PytorchQuantizationWrapper(layer, - weights_quantizers={attr: ConfigurableWeightsQuantizer( - **self._get_weights_configurable_quantizer_kwargs(n, attr)) - for attr in kernel_attributes}) + weights_quantizers={ + kernel_attr: ConfigurableWeightsQuantizer( + **self._get_weights_configurable_quantizer_kwargs(n, + kernel_attr), + kernel_attr=kernel_attr)}) else: - node_weights_qc = n.get_unique_weights_candidates() + # TODO: Do we want to include other quantized attributes that are not + # the kernel attribute in the mixed precision model? + # Currently, we only consider kernel attribute quantization (whether it is in mixed precision + # or single precision). 
+ node_weights_qc = n.get_unique_weights_candidates(kernel_attr) if not len(node_weights_qc) == 1: Logger.error(f"Expecting node {n.name} to have a unique weights configuration " # pragma: no cover f"but {len(node_weights_qc)} different configurations exist.") quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights, - node_weights_qc[0].weights_quantization_cfg.weights_quantization_method, + node_weights_qc[0].weights_quantization_cfg + .get_attr_config(kernel_attr) + .weights_quantization_method, BasePyTorchInferableQuantizer) - kwargs = get_weights_inferable_quantizer_kwargs(node_weights_qc[0].weights_quantization_cfg) + kwargs = get_weights_inferable_quantizer_kwargs(node_weights_qc[0].weights_quantization_cfg, + kernel_attr) return PytorchQuantizationWrapper(layer, - weights_quantizers={attr: quantier_for_node(**kwargs) - for attr in kernel_attributes}) + weights_quantizers={kernel_attr: quantier_for_node(**kwargs)}) return layer @@ -124,7 +132,7 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None" node_q_cfg_candidates = n.candidates_quantization_cfg # sort by descending bit width so using indices would be easier - node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.weights_n_bits, + node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.get_attr_config(attr).weights_n_bits, x.activation_quantization_cfg.activation_n_bits), reverse=True) float_weights = n.get_weights_by_keys(attr) @@ -162,18 +170,22 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua if n.name in activation_conf_nodes_names: assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None" node_q_cfg_candidates = n.candidates_quantization_cfg - # sort by descending bit width so using indices would be easier - node_q_cfg_candidates.sort(key=lambda x: (x.weights_quantization_cfg.weights_n_bits, - x.activation_quantization_cfg.activation_n_bits), - reverse=True) + + # sorting the candidates by kernel attribute weights number of bits first and then by + # activation number of bits (in reversed order). + # since only kernel attribute is quantized in weights mixed precision, + # if the node doesn't have a kernel attribute, we only sort by activation_n_bits. 
+ n.sort_node_candidates(self.fw_info) max_cfg_candidates = n.find_max_candidates_indices() assert len(max_cfg_candidates) == 1, \ f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates" max_candidate_idx = max_cfg_candidates[0] + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates, - 'max_candidate_idx': max_candidate_idx})] \ + 'max_candidate_idx': max_candidate_idx, + 'kernel_attr': kernel_attr})] \ * num_of_outputs else: node_act_qc = n.get_unique_activation_candidates() @@ -207,7 +219,7 @@ def build_model(self) -> Tuple[torch.nn.Module, UserInformation, # creating a mapping between graph nodes and model's layers for mixed precision configurability model_layers = dict(model.named_children()) conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model_layers) - for n in self.graph.get_configurable_sorted_nodes()} + for n in self.graph.get_configurable_sorted_nodes(self.fw_info)} return model, user_info, conf_node2layers @@ -259,7 +271,9 @@ def _find_layers_in_model_by_node(self, n: BaseNode, named_layers: Dict[str, tor Returns: A list of layers that responsible for the node's quantization. """ - weights_quant = n.is_weights_quantization_enabled() + # Only layers with kernel op are considered weights configurable + kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0] + weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr) act_quant = n.is_activation_quantization_enabled() if weights_quant and not act_quant: diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index bff1a5258..1e9aeb855 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -277,7 +277,7 @@ def forward(self, """ node_to_output_tensors_dict = dict() node_to_output_tensors_dict_float = dict() - configurable_nodes = self.graph.get_configurable_sorted_nodes_names() + configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO) for node in self.node_sort: input_tensors = _build_input_tensors_list(node, self.graph, diff --git a/model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py b/model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py index 818b30262..c856bdf12 100644 --- a/model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +++ b/model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py @@ -44,7 +44,8 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer): def __init__(self, node_q_cfg: List[CandidateNodeQuantizationConfig], - max_candidate_idx: int = 0): + max_candidate_idx: int = 0, + kernel_attr: str = None): """ Initializes a configurable quantizer. @@ -52,13 +53,14 @@ def __init__(self, node_q_cfg: Quantization configuration candidates of the node that generated the layer that will use this quantizer. max_candidate_idx: Index of the node's candidate that has the maximal bitwidth (must exist absolute max). + kernel_attr: A kernel attribute name if the node have a kernel attribute (used only for candidates order validation). 
""" super(ConfigurableActivationQuantizer, self).__init__() self.node_q_cfg = node_q_cfg - verify_candidates_descending_order(self.node_q_cfg) + verify_candidates_descending_order(self.node_q_cfg, kernel_attr) for qc in self.node_q_cfg: if qc.activation_quantization_cfg.enable_activation_quantization != \ diff --git a/model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py b/model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py index 7c9cbceb3..585df69fe 100644 --- a/model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +++ b/model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py @@ -52,6 +52,7 @@ class ConfigurableWeightsQuantizer(BasePyTorchInferableQuantizer): def __init__(self, node_q_cfg: List[CandidateNodeQuantizationConfig], float_weights: torch.Tensor, + kernel_attr: str, max_candidate_idx: int = 0): """ Initializes a configurable quantizer. @@ -60,6 +61,7 @@ def __init__(self, node_q_cfg: Quantization configuration candidates of the node that generated the layer that will use this quantizer. float_weights: Float weights of the layer. + kernel_attr: The kernel attribute name of the node. Only layers with kernel op can be configured. max_candidate_idx: Index of the node's candidate that has the maximal bitwidth (must exist absolute max). """ @@ -68,18 +70,21 @@ def __init__(self, self.node_q_cfg = node_q_cfg self.float_weights = float_weights self.max_candidate_idx = max_candidate_idx + self.kernel_attr = kernel_attr - verify_candidates_descending_order(self.node_q_cfg) + verify_candidates_descending_order(self.node_q_cfg, kernel_attr) for qc in self.node_q_cfg: - if qc.weights_quantization_cfg.enable_weights_quantization != \ - self.node_q_cfg[0].weights_quantization_cfg.enable_weights_quantization: - Logger.error("Candidates with different weights enabled properties is currently not supported.") # pragma: no cover + if qc.weights_quantization_cfg.get_attr_config(self.kernel_attr).enable_weights_quantization != \ + self.node_q_cfg[0].weights_quantization_cfg.get_attr_config(self.kernel_attr).enable_weights_quantization: + Logger.error("Candidates with different kernel attribute quantization enabled " + "properties is currently not supported.") # Initialize quantized weights for each weight that should be quantized. 
self.quantized_weights = init_quantized_weights(node_q_cfg=self.node_q_cfg, float_weights=self.float_weights, - fw_tensor_convert_func=to_torch_tensor) + fw_tensor_convert_func=to_torch_tensor, + kernel_attr=kernel_attr) self.active_quantization_config_index = self.max_candidate_idx diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 2b9f2a3ec..1a36febfd 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -86,6 +86,10 @@ from model_compression_toolkit.core.pytorch.statistics_correction.apply_second_moment_correction import \ pytorch_apply_second_moment_correction from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model +from model_compression_toolkit.exporter.model_wrapper.fw_agnostic.get_inferable_quantizers import \ + get_inferable_quantizers +from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \ + get_weights_quantizer_for_node, get_activations_quantizer_for_node from model_compression_toolkit.logger import Logger @@ -536,3 +540,21 @@ def get_trace_hessian_calculator(self, input_images=input_images, fw_impl=self, num_iterations_for_approximation=num_iterations_for_approximation) + + def get_inferable_quantizers(self, node: BaseNode): + """ + Returns sets of Pytorch compatible weights and activation quantizers for the given node. + + Args: + node: Node to get quantizers for. + + Returns: + weight_quantizers: A dictionary between a weight's name to its quantizer. + activation_quantizers: A list of activations quantization, one for each layer output. + + """ + + return get_inferable_quantizers(node, + get_weights_quantizer_for_node, + get_activations_quantizer_for_node, + node.get_node_weights_attributes()) \ No newline at end of file diff --git a/model_compression_toolkit/core/quantization_prep_runner.py b/model_compression_toolkit/core/quantization_prep_runner.py index 83d8be961..e22db8186 100644 --- a/model_compression_toolkit/core/quantization_prep_runner.py +++ b/model_compression_toolkit/core/quantization_prep_runner.py @@ -86,9 +86,7 @@ def quantization_preparation_runner(graph: Graph, ###################################### # Calculate quantization params ###################################### - calculate_quantization_params(graph, - fw_info, - fw_impl=fw_impl) + calculate_quantization_params(graph) if tb_w is not None: tb_w.add_graph(graph, 'thresholds_selection') diff --git a/model_compression_toolkit/core/runner.py b/model_compression_toolkit/core/runner.py index 48e984e3c..7a8ad2110 100644 --- a/model_compression_toolkit/core/runner.py +++ b/model_compression_toolkit/core/runner.py @@ -141,7 +141,7 @@ def core_runner(in_model: Any, if target_kpi is not None: # Retrieve lists of tuples (node, node's final weights/activation bitwidth) - weights_conf_nodes_bitwidth = tg.get_final_weights_config() + weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info) activation_conf_nodes_bitwidth = tg.get_final_activation_config() Logger.info( diff --git a/model_compression_toolkit/exporter/model_wrapper/fw_agnostic/__init__.py b/model_compression_toolkit/exporter/model_wrapper/fw_agnostic/__init__.py new file mode 100644 index 000000000..ea3047f32 --- /dev/null +++ b/model_compression_toolkit/exporter/model_wrapper/fw_agnostic/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Sony Semiconductor 
Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== \ No newline at end of file diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py b/model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quantizers.py similarity index 64% rename from model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py rename to model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quantizers.py index 28e849cdf..fdda9f837 100644 --- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py +++ b/model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quantizers.py @@ -12,36 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import Tuple, List, Dict +from typing import Dict, List, Tuple, Callable from model_compression_toolkit.core.common import BaseNode -from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO -from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \ - get_activations_quantizer_for_node, get_weights_quantizer_for_node -def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]: +def get_inferable_quantizers(node: BaseNode, + get_weights_quantizer_for_node: Callable, + get_activations_quantizer_for_node: Callable, + attributes_names: List[str] = []) -> Tuple[Dict, List]: """ Create quantizers to wrap a layer for its corresponding node. Args: node: Node to create quantizers for. + get_weights_quantizer_for_node: A function that returns weights quantizer for the node attributes. + get_activations_quantizer_for_node: A function that returns activation quantizer for the node activation tensor. + attributes_names: A potential list of attribute names to set weights quantizers to. Returns: weight_quantizers: A dictionary between a weight's name to its quantizer. activation_quantizers: A list of activations quantization, one for each layer output. 
- """ + weight_quantizers = {} - if node.is_weights_quantization_enabled(): - weight_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(node.type) - weight_quantizer = get_weights_quantizer_for_node(node) - for attr in weight_attrs: + activation_quantizers = [] + + for attr in attributes_names: + if node.is_weights_quantization_enabled(attr): + weight_quantizer = get_weights_quantizer_for_node(node, attr) weight_quantizers[attr] = weight_quantizer - activation_quantizers = [] if node.is_activation_quantization_enabled(): num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1 activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs - return weight_quantizers, activation_quantizers diff --git a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py index ddcc9869c..28162d8e2 100644 --- a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +++ b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py @@ -19,17 +19,18 @@ from model_compression_toolkit.constants import FOUND_TF from model_compression_toolkit.core.common.user_info import UserInformation from model_compression_toolkit.logger import Logger -from mct_quantizers import KerasActivationQuantizationHolder +import model_compression_toolkit.core as C if FOUND_TF: import tensorflow as tf from tensorflow.keras.layers import Layer from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder - from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers from mct_quantizers import KerasQuantizationWrapper + from mct_quantizers import KerasActivationQuantizationHolder def _get_wrapper(node: common.BaseNode, - layer: Layer) -> Layer: + layer: Layer, + fw_impl=None) -> Layer: """ A function which takes a computational graph node and a keras layer and perform the quantization wrapping Args: @@ -39,14 +40,14 @@ def _get_wrapper(node: common.BaseNode, Returns: Wrapped layer with weights quantizers and activation quantizers """ - weights_quantizers, _ = get_quantization_quantizers(node) + weights_quantizers, _ = fw_impl.get_inferable_quantizers(node) if len(weights_quantizers) > 0: return KerasQuantizationWrapper(layer, weights_quantizers) return layer - def get_activation_quantizer_holder(node: common.BaseNode) -> Callable: + def get_activation_quantizer_holder(node: common.BaseNode, fw_impl) -> Callable: """ Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node. @@ -56,7 +57,7 @@ def get_activation_quantizer_holder(node: common.BaseNode) -> Callable: Returns: A ActivationQuantizationHolder layer for the node activation quantization. 
""" - _, activation_quantizers = get_quantization_quantizers(node) + _, activation_quantizers = fw_impl.get_inferable_quantizers(node) # Holder by definition uses a single quantizer for the activation quantization # thus we make sure this is the only possible case (unless it's a node with no activation @@ -68,8 +69,6 @@ def get_activation_quantizer_holder(node: common.BaseNode) -> Callable: f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers ' f'were found for node {node}') - - def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]: """ Convert graph to an exportable Keras model (model with all quantization parameters). @@ -83,8 +82,12 @@ def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, Use Exportable Keras model and user information. """ exportable_model, user_info = KerasModelBuilder(graph=graph, - wrapper=_get_wrapper, - get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model() + wrapper=lambda n, kn: + _get_wrapper(n, kn, + fw_impl=C.keras.keras_implementation.KerasImplementation()), + get_activation_quantizer_holder_fn=lambda n: + get_activation_quantizer_holder(n, + fw_impl=C.keras.keras_implementation.KerasImplementation())).build_model() exportable_model.trainable = False return exportable_model, user_info else: diff --git a/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py b/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py index 142b2566e..7ed40d355 100644 --- a/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +++ b/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py @@ -28,13 +28,15 @@ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig, - quantization_target: QuantizationTarget) -> Dict[str, Any]: + quantization_target: QuantizationTarget, + attr_name: str = None) -> Dict[str, Any]: """ Get the quantization parameters for an inferable quantizer. Args: node_qc: The node quantization configuration of the node for which the quantizer is being created. Needs to match the specific quantization target. quantization_target: The target of the quantization (weights or activations). + attr_name: The weights attribute to get its quantizer kwargs (if target is weights quantization). Returns: The quantization parameters as a dictionary. 
@@ -44,33 +46,38 @@ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig, if not isinstance(node_qc, NodeWeightsQuantizationConfig): Logger.error(f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover - quantization_method = node_qc.weights_quantization_method + if attr_name is None: + Logger.error(f"Attribute name was not specified for retrieving weights quantizer kwargs.") + + attr_node_qc = node_qc.get_attr_config(attr_name=attr_name) + + quantization_method = attr_node_qc.weights_quantization_method # Return the appropriate quantization parameters based on the quantization method if quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]: - return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits, - qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[THRESHOLD].flatten()), - qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis, - qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[THRESHOLD].shape)} + return {qi_keras_consts.NUM_BITS: attr_node_qc.weights_n_bits, + qi_keras_consts.THRESHOLD: list(attr_node_qc.weights_quantization_params[THRESHOLD].flatten()), + qi_keras_consts.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_keras_consts.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis + qi_keras_consts.INPUT_RANK: len(attr_node_qc.weights_quantization_params[THRESHOLD].shape)} elif quantization_method in [QuantizationMethod.UNIFORM]: - return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits, - qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_keras_consts.MIN_RANGE: list(node_qc.weights_quantization_params[RANGE_MIN].flatten()), - qi_keras_consts.MAX_RANGE: list(node_qc.weights_quantization_params[RANGE_MAX].flatten()), - qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis, - qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[RANGE_MIN].shape)} + return {qi_keras_consts.NUM_BITS: attr_node_qc.weights_n_bits, + qi_keras_consts.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_keras_consts.MIN_RANGE: list(attr_node_qc.weights_quantization_params[RANGE_MIN].flatten()), + qi_keras_consts.MAX_RANGE: list(attr_node_qc.weights_quantization_params[RANGE_MAX].flatten()), + qi_keras_consts.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis + qi_keras_consts.INPUT_RANK: len(attr_node_qc.weights_quantization_params[RANGE_MIN].shape)} elif quantization_method in [QuantizationMethod.LUT_SYM_QUANTIZER, QuantizationMethod.LUT_POT_QUANTIZER]: - return {qi_keras_consts.NUM_BITS: node_qc.weights_n_bits, - qi_keras_consts.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_keras_consts.LUT_VALUES: list(node_qc.weights_quantization_params[LUT_VALUES].flatten()), - qi_keras_consts.THRESHOLD: list(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()), - qi_keras_consts.CHANNEL_AXIS: node_qc.weights_channels_axis, + return {qi_keras_consts.NUM_BITS: attr_node_qc.weights_n_bits, + qi_keras_consts.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_keras_consts.LUT_VALUES: list(attr_node_qc.weights_quantization_params[LUT_VALUES].flatten()), + qi_keras_consts.THRESHOLD: list(attr_node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten()), + qi_keras_consts.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis # TODO: how to pass 
multiplier nbits and eps for a specific node? - qi_keras_consts.INPUT_RANK: len(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)} + qi_keras_consts.INPUT_RANK: len(attr_node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)} else: Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover @@ -108,24 +115,28 @@ def get_inferable_quantizer_kwargs(node_qc: BaseNodeQuantizationConfig, Logger.critical(f'{quantization_target} is not supported') # pragma: no cover -def get_weights_quantizer_for_node(node: BaseNode) -> BaseKerasInferableQuantizer: +def get_weights_quantizer_for_node(node: BaseNode, attr_name: str) -> BaseKerasInferableQuantizer: """ - Get weights quantizer for a node. + Get weights quantizer for a weights attribute of a node. + Args: node: Node to create a weight quantizer for. + attr_name: Attribute name to get its quantizer. + Returns: - Quantizer for the node's weights. + Quantizer for the node's weights attribute. """ if node.final_weights_quantization_cfg is None: Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration') # pragma: # no cover node_w_qc = node.final_weights_quantization_cfg - weights_quantization_method = node_w_qc.weights_quantization_method + weights_quantization_method = node_w_qc.get_attr_config(attr_name).weights_quantization_method quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights, weights_quantization_method, BaseKerasInferableQuantizer) - kwargs = get_inferable_quantizer_kwargs(node_w_qc, QuantizationTarget.Weights) + + kwargs = get_inferable_quantizer_kwargs(node_w_qc, QuantizationTarget.Weights, attr_name) return quantier_for_node(**kwargs) diff --git a/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py b/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py deleted file mode 100644 index 92e61aaaa..000000000 --- a/model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizers.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from typing import Dict, List, Tuple -from model_compression_toolkit.core.common import BaseNode -from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO -from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ - get_weights_quantizer_for_node, get_activations_quantizer_for_node - - -def get_quantization_quantizers(node: BaseNode) -> Tuple[Dict, List]: - """ - Create quantizers to wrap a layer for its corresponding node. - - Args: - node: Node to create quantizers for. - - Returns: - weight_quantizers: A dictionary between a weight's name to its quantizer. - activation_quantizers: A list of activations quantization, one for each layer output. 
- """ - weight_quantizers = {} - activation_quantizers = [] - - if node.is_weights_quantization_enabled(): - weight_attrs = DEFAULT_KERAS_INFO.get_kernel_op_attributes(node.type) - weight_quantizer = get_weights_quantizer_for_node(node) - for attr in weight_attrs: - weight_quantizers[attr] = weight_quantizer - - if node.is_activation_quantization_enabled(): - num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1 - activation_quantizers = [get_activations_quantizer_for_node(node)] * num_of_outputs - - return weight_quantizers, activation_quantizers diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py index dea2a5fd7..eb318c140 100644 --- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py @@ -19,17 +19,17 @@ from model_compression_toolkit.constants import FOUND_TORCH from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common import BaseNode +import model_compression_toolkit.core as C if FOUND_TORCH: import torch from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder - from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizers import \ - get_quantization_quantizers def fully_quantized_wrapper(node: common.BaseNode, - module: torch.nn.Module) -> Union[torch.nn.Module,PytorchQuantizationWrapper]: + module: torch.nn.Module, + fw_impl) -> Union[torch.nn.Module,PytorchQuantizationWrapper]: """ A function which takes a computational graph node and a pytorch module and perform the quantization wrapping @@ -40,12 +40,12 @@ def fully_quantized_wrapper(node: common.BaseNode, Returns: Wrapped layer """ - weight_quantizers, _ = get_quantization_quantizers(node) + weight_quantizers, _ = fw_impl.get_inferable_quantizers(node) if len(weight_quantizers) > 0: return PytorchQuantizationWrapper(module, weight_quantizers) return module - def get_activation_quantizer_holder(node: BaseNode) -> Callable: + def get_activation_quantizer_holder(node: BaseNode, fw_impl) -> Callable: """ Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node. If the layer is not supposed to be wrapped with an activation quantizer - return None. @@ -54,7 +54,7 @@ def get_activation_quantizer_holder(node: BaseNode) -> Callable: Returns: A PytorchActivationQuantizationHolder module for the node's activation quantization. """ - _, activation_quantizers = get_quantization_quantizers(node) + _, activation_quantizers = fw_impl.get_inferable_quantizers(node) # Holder by definition uses a single quantizer for the activation quantization # thus we make sure this is the only possible case (unless it's a node we no activation # quantization, which in this case has an empty list). @@ -75,8 +75,12 @@ def get_exportable_pytorch_model(graph: Graph): Fully quantized PyTorch model. 
""" return PyTorchModelBuilder(graph=graph, - wrapper=fully_quantized_wrapper, - get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model() + wrapper=lambda n, m: + fully_quantized_wrapper(n, m, + fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation()), + get_activation_quantizer_holder_fn=lambda n: + get_activation_quantizer_holder(n, + fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation())).build_model() else: def get_exportable_pytorch_model(*args, **kwargs): # pragma: no cover diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py index 85b248a51..9f38218f0 100644 --- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py @@ -30,12 +30,14 @@ import numpy as np -def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfig) -> Dict[str, Any]: +def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfig, attr_name: str) -> Dict[str, Any]: """ Get the quantization parameters for a weights inferable quantizer. Args: node_qc: The node quantization configuration of the node for which the quantizer is being created. Needs to match the specific quantization target. + attr_name: The weights attribute to get its quantizer kwargs (if target is weights quantization). + Returns: The quantization parameters as a dictionary. @@ -45,30 +47,37 @@ def get_weights_inferable_quantizer_kwargs(node_qc: NodeWeightsQuantizationConfi Logger.error( f"Non-compatible node quantization config was given for quantization target Weights.") # pragma: no cover - quantization_method = node_qc.weights_quantization_method + if attr_name is None: + Logger.error(f"Attribute name was not specified for retrieving weights quantizer kwargs.") + + attr_node_qc = node_qc.get_attr_config(attr_name=attr_name) + + quantization_method = attr_node_qc.weights_quantization_method # Return the appropriate quantization parameters based on the quantization method if quantization_method in [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC]: - return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits, - qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[THRESHOLD].flatten().tolist(), - qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis} + return {qi_inferable_quantizers_constants.NUM_BITS: attr_node_qc.weights_n_bits, + qi_inferable_quantizers_constants.THRESHOLD: attr_node_qc.weights_quantization_params[THRESHOLD].flatten().tolist(), + qi_inferable_quantizers_constants.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_inferable_quantizers_constants.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis + } elif quantization_method in [QuantizationMethod.UNIFORM]: - return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits, - qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_inferable_quantizers_constants.MIN_RANGE: node_qc.weights_quantization_params[RANGE_MIN].flatten().tolist(), - qi_inferable_quantizers_constants.MAX_RANGE: node_qc.weights_quantization_params[RANGE_MAX].flatten().tolist(), - qi_inferable_quantizers_constants.CHANNEL_AXIS: 
node_qc.weights_channels_axis} + return {qi_inferable_quantizers_constants.NUM_BITS: attr_node_qc.weights_n_bits, + qi_inferable_quantizers_constants.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_inferable_quantizers_constants.MIN_RANGE: attr_node_qc.weights_quantization_params[RANGE_MIN].flatten().tolist(), + qi_inferable_quantizers_constants.MAX_RANGE: attr_node_qc.weights_quantization_params[RANGE_MAX].flatten().tolist(), + qi_inferable_quantizers_constants.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis + } elif quantization_method in [QuantizationMethod.LUT_POT_QUANTIZER, QuantizationMethod.LUT_SYM_QUANTIZER]: - return {qi_inferable_quantizers_constants.NUM_BITS: node_qc.weights_n_bits, - qi_inferable_quantizers_constants.LUT_VALUES: node_qc.weights_quantization_params[LUT_VALUES].flatten().tolist(), - qi_inferable_quantizers_constants.THRESHOLD: node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten().tolist(), - qi_inferable_quantizers_constants.PER_CHANNEL: node_qc.weights_per_channel_threshold, - qi_inferable_quantizers_constants.CHANNEL_AXIS: node_qc.weights_channels_axis, - qi_inferable_quantizers_constants.INPUT_RANK: len(node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)} + return {qi_inferable_quantizers_constants.NUM_BITS: attr_node_qc.weights_n_bits, + qi_inferable_quantizers_constants.LUT_VALUES: attr_node_qc.weights_quantization_params[LUT_VALUES].flatten().tolist(), + qi_inferable_quantizers_constants.THRESHOLD: attr_node_qc.weights_quantization_params[SCALE_PER_CHANNEL].flatten().tolist(), + qi_inferable_quantizers_constants.PER_CHANNEL: attr_node_qc.weights_per_channel_threshold, + qi_inferable_quantizers_constants.CHANNEL_AXIS: attr_node_qc.weights_channels_axis[0], # output channel axis + qi_inferable_quantizers_constants.INPUT_RANK: len(attr_node_qc.weights_quantization_params[SCALE_PER_CHANNEL].shape)} # TODO: Add LUT_VALUES_BITWIDTH & EPS to node quantization config else: @@ -115,12 +124,13 @@ def get_activation_inferable_quantizer_kwargs(node_qc: NodeActivationQuantizatio Logger.critical(f'Not supported quantization method for inferable quantizers.') # pragma: no cover -def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuantizer: +def get_weights_quantizer_for_node(node: BaseNode, attr_name: str) -> BasePyTorchInferableQuantizer: """ Get weights quantizer for a node. Args: node: Node to create a weight quantizer for. + attr_name: Attribute name to get its quantizer. Returns: Quantizer for the node's weights. 
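
    # Reader's sketch: get_weights_quantizer_for_node(node, attr_name), whose new
    # per-attribute signature is documented just above, is the kind of callable consumed by
    # the fw-agnostic get_inferable_quantizers introduced earlier in this patch. The
    # self-contained loop below mirrors that collection logic; FakeNode and the lambda
    # builders are hypothetical stand-ins, not MCT classes.
    from typing import Callable, Dict, List, Tuple


    class FakeNode:
        def __init__(self, enabled_weight_attrs, activation_enabled, num_outputs=1):
            self._enabled_weight_attrs = set(enabled_weight_attrs)
            self._activation_enabled = activation_enabled
            self.output_shape = [None] * num_outputs

        def is_weights_quantization_enabled(self, attr: str) -> bool:
            return attr in self._enabled_weight_attrs

        def is_activation_quantization_enabled(self) -> bool:
            return self._activation_enabled


    def collect_quantizers(node,
                           weights_builder: Callable,
                           activation_builder: Callable,
                           attributes_names: List[str]) -> Tuple[Dict, List]:
        # One weights quantizer per enabled attribute, keyed by attribute name.
        weight_quantizers = {attr: weights_builder(node, attr)
                             for attr in attributes_names
                             if node.is_weights_quantization_enabled(attr)}
        activation_quantizers = []
        if node.is_activation_quantization_enabled():
            num_of_outputs = len(node.output_shape) if isinstance(node.output_shape, list) else 1
            activation_quantizers = [activation_builder(node)] * num_of_outputs
        return weight_quantizers, activation_quantizers


    node = FakeNode(enabled_weight_attrs={"kernel"}, activation_enabled=True)
    w_q, a_q = collect_quantizers(node,
                                  weights_builder=lambda n, attr: f"weights-quantizer({attr})",
                                  activation_builder=lambda n: "activation-quantizer",
                                  attributes_names=["kernel", "bias"])
    print(w_q)   # only the enabled 'kernel' attribute gets a quantizer
    print(a_q)   # one activation quantizer per layer output
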
@@ -130,12 +140,12 @@ def get_weights_quantizer_for_node(node: BaseNode) -> BasePyTorchInferableQuanti Logger.critical(f'Can not set quantizer for a node with no final weights quantization configuration') # pragma: # no cover node_w_qc = node.final_weights_quantization_cfg - weights_quantization_method = node_w_qc.weights_quantization_method + weights_quantization_method = node_w_qc.get_attr_config(attr_name).weights_quantization_method quantier_for_node = get_inferable_quantizer_class(QuantizationTarget.Weights, weights_quantization_method, BasePyTorchInferableQuantizer) - kwargs = get_weights_inferable_quantizer_kwargs(node_w_qc) + kwargs = get_weights_inferable_quantizer_kwargs(node_w_qc, attr_name) return quantier_for_node(**kwargs) diff --git a/model_compression_toolkit/gptq/common/gptq_graph.py b/model_compression_toolkit/gptq/common/gptq_graph.py index 13ef6b2cc..2beedb8ac 100644 --- a/model_compression_toolkit/gptq/common/gptq_graph.py +++ b/model_compression_toolkit/gptq/common/gptq_graph.py @@ -39,7 +39,9 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L compare_points_std = [] compare_points_name = [] for n in input_graph.get_topo_sorted_nodes(): - if len(n.weights) > 0 and n.is_weights_quantization_enabled() and not n.reuse: + # only nodes with kernel attribute are currently trained with GPTQ and are used as compare points + kernel_attr = input_graph.fw_info.get_kernel_op_attributes(n.type)[0] + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr) and not n.reuse: compare_points.append(n) compare_points_name.append(n.name) compare_points_std.append(n.prior_info.std_output) diff --git a/model_compression_toolkit/gptq/keras/gptq_training.py b/model_compression_toolkit/gptq/keras/gptq_training.py index ec2acfcd1..1d731623a 100644 --- a/model_compression_toolkit/gptq/keras/gptq_training.py +++ b/model_compression_toolkit/gptq/keras/gptq_training.py @@ -129,11 +129,8 @@ def _is_gptq_weights_trainable(self, Returns: A boolean whether the layer is to be wrapped with a QuantizeWrapper """ - - if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type): - Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} " - f"without a kernel isn't supported") - return node.is_weights_quantization_enabled() + kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0] + return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr) def gptq_wrapper(self, n: common.BaseNode, @@ -149,11 +146,16 @@ def gptq_wrapper(self, """ if self._is_gptq_weights_trainable(n): + # If we are here, then the node has a kernel attribute to quantize and training during GPTQ weights_quantizers, _ = quantization_builder(n, - self.gptq_config) # TODO: split quantizers building into two functions: for weights and activations + self.gptq_config, # TODO: split quantizers building into two functions: for weights and activations + self.fw_info.get_kernel_op_attributes(n.type)[0]) if len(weights_quantizers) > 0: return KerasTrainableQuantizationWrapper(layer, - weights_quantizers=weights_quantizers) + weights_quantizers=weights_quantizers) + + # TODO: need to check if in this case, if there are other weights attributes that are not trainable but are + # quantized, do we need to wrap them as well? 
return layer def get_activation_quantizer_holder(self, n: common.BaseNode) -> Callable: diff --git a/model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py b/model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py index ad40d9ff5..e7cb69afe 100644 --- a/model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +++ b/model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py @@ -16,10 +16,8 @@ from model_compression_toolkit.gptq import GradientPTQConfig from model_compression_toolkit.core import common -from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \ get_inferable_quantizer_kwargs -from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq from model_compression_toolkit.gptq.keras.quantizer.base_keras_gptq_quantizer import BaseKerasGPTQTrainableQuantizer from mct_quantizers import QuantizationTarget from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class @@ -33,15 +31,16 @@ def quantization_builder(n: common.BaseNode, - gptq_config: GradientPTQConfig - ) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]: + gptq_config: GradientPTQConfig, + kernel_attr: str = None) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]: """ Build quantizers for a node according to its quantization configuration and a global NoOpQuantizeConfig object. Args: n: Node to build its QuantizeConfig. - gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration. + gptq_config (GradientPTQConfig): GradientPTQConfig configuration. + kernel_attr: A potential kernel attribute name to build its trainable quantizer. Returns: A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training. @@ -50,18 +49,19 @@ def quantization_builder(n: common.BaseNode, """ weights_quantizers = {} - if n.is_weights_quantization_enabled(): - quant_method = n.final_weights_quantization_cfg.weights_quantization_method + + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): + # Only nodes with kernel attribute are trainable during GPTQ + quant_method = n.final_weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights, quantizer_id=gptq_config.rounding_type, quant_method=quant_method, quantizer_base_class=BaseKerasGPTQTrainableQuantizer) - kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=n.type, - fw_info=DEFAULT_KERAS_INFO) - weights_quantizers.update({kernel_attribute: quantizer_class(get_trainable_quantizer_weights_config(n), - **gptq_config.gptq_quantizer_params_override)}) + weights_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n, + kernel_attr), + **gptq_config.gptq_quantizer_params_override)}) activation_quantizers = [] if n.is_activation_quantization_enabled(): diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index e20ee2c37..d53d3cdec 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -103,16 +103,16 @@ def _is_gptq_weights_trainable(self, node: BaseNode) -> bool: """ A function for deciding if a layer should be fine-tuned during GPTQ. 
+ Args: node (BaseNode): Node for quantization decision + Returns: A boolean whether the layer is to be wrapped with a Quantization Wrapper. """ - if node.is_weights_quantization_enabled() and not self.fw_info.is_kernel_op(node.type): - Logger.error(f"GPTQ Error: Quantizing node {node.name} of type {node.type} " - f"without a kernel isn't supported.") - return node.is_weights_quantization_enabled() + kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0] + return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr) def gptq_wrapper(self, n: BaseNode, @@ -128,11 +128,18 @@ def gptq_wrapper(self, """ if self._is_gptq_weights_trainable(n): - weights_quantizers, activation_quantizers = quantization_builder(n, self.gptq_config) - return PytorchQuantizationWrapper(layer, - weights_quantizers=weights_quantizers) - else: - return layer + # If we are here, then the node has a kernel attribute to quantize and training during GPTQ + weights_quantizers, _ = quantization_builder(n, + self.gptq_config, + self.fw_info.get_kernel_op_attributes(n.type)[0]) + + if len(weights_quantizers) > 0: + return PytorchQuantizationWrapper(layer, + weights_quantizers=weights_quantizers) + + # TODO: need to check if in this case, if there are other weights attributes that are not trainable but are + # quantized, do we need to wrap them as well? + return layer def get_activation_quantizer_holder(self, n: BaseNode) -> Callable: """ diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py b/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py index 4af79afca..2723c5ec7 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py @@ -16,7 +16,6 @@ from model_compression_toolkit.gptq import GradientPTQConfig from model_compression_toolkit.core import common -from model_compression_toolkit.core.pytorch.constants import KERNEL from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \ get_activation_inferable_quantizer_kwargs from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \ @@ -35,15 +34,16 @@ def quantization_builder(n: common.BaseNode, gptq_config: GradientPTQConfig, - ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer], - List[BasePyTorchInferableQuantizer]]: + kernel_attr: str = None + ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer], List[BasePyTorchInferableQuantizer]]: """ Build quantizers for a node according to its quantization configuration and a global NoOpQuantizeConfig object. Args: n: Node to build its QuantizeConfig. - gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration. + gptq_config (GradientPTQConfig): GradientPTQConfig configuration. + kernel_attr: A potential kernel attribute name to build its trainable quantizer. Returns: A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training. 
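
    # Reader's sketch of the kernel-only gating applied by this GPTQ quantizer builder (and
    # by the matching Keras GPTQ and QAT builders in this patch). build_weights_quantizers is
    # a hypothetical stand-in; the real code resolves the kernel attribute via
    # fw_info.get_kernel_op_attributes(node.type)[0] and checks
    # node.is_weights_quantization_enabled(kernel_attr).
    from typing import Dict, Optional


    def build_weights_quantizers(node_enabled_attrs: set,
                                 kernel_attr: Optional[str]) -> Dict[str, str]:
        """Return an {attribute-name: trainable-quantizer} mapping for the kernel attribute only."""
        if kernel_attr is None or kernel_attr not in node_enabled_attrs:
            # Nodes without a quantized kernel attribute are not trained during GPTQ/QAT.
            return {}
        return {kernel_attr: f"trainable-quantizer({kernel_attr})"}


    print(build_weights_quantizers({"weight", "bias"}, kernel_attr="weight"))  # {'weight': ...}
    print(build_weights_quantizers({"bias"}, kernel_attr="weight"))            # {} -> layer left unwrapped
    print(build_weights_quantizers({"weight"}, kernel_attr=None))              # {} -> no kernel op
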
@@ -52,14 +52,16 @@ def quantization_builder(n: common.BaseNode, """ weights_quantizers = {} - if n.is_weights_quantization_enabled(): - quant_method = n.final_weights_quantization_cfg.weights_quantization_method + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): + # Only nodes with kernel attribute are trainable during GPTQ + quant_method = n.final_weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Weights, quantizer_id=gptq_config.rounding_type, quant_method=quant_method, quantizer_base_class=BasePytorchGPTQTrainableQuantizer) - weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n), - **gptq_config.gptq_quantizer_params_override)}) + weights_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n, + kernel_attr), + **gptq_config.gptq_quantizer_params_override)}) activation_quantizers = [] if n.is_activation_quantization_enabled(): if n.final_activation_quantization_cfg is None: diff --git a/model_compression_toolkit/qat/common/qat_config.py b/model_compression_toolkit/qat/common/qat_config.py index 1817191c4..08c7b7229 100644 --- a/model_compression_toolkit/qat/common/qat_config.py +++ b/model_compression_toolkit/qat/common/qat_config.py @@ -24,6 +24,7 @@ def is_qat_applicable(node: common.BaseNode, fw_info: FrameworkInfo) -> bool: """ A function for deciding if a layer should be fine-tuned during QAT + Args: node (BaseNode): Node for quantization decision fw_info (FrameworkInfo): Pytorch quantization information @@ -32,9 +33,10 @@ def is_qat_applicable(node: common.BaseNode, A boolean whether the layer is to be wrapped with a QuantizeWrapper """ - if node.is_weights_quantization_enabled() and not fw_info.is_kernel_op(node.type): - Logger.error("QAT Error: Quantizing a node without a kernel isn't supported") - return node.is_weights_quantization_enabled() or node.is_activation_quantization_enabled() + kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0] + return (kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)) \ + or node.is_activation_quantization_enabled() + class TrainingMethod(Enum): diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py index cc419678e..540c41c35 100644 --- a/model_compression_toolkit/qat/keras/quantization_facade.py +++ b/model_compression_toolkit/qat/keras/quantization_facade.py @@ -72,12 +72,16 @@ def qat_wrapper(n: common.BaseNode, """ if is_qat_applicable(n, DEFAULT_KERAS_INFO): + # If we are here, then the node has a kernel attribute to quantize and training during QAT weights_quantizers, _ = quantization_builder(n, qat_config, - DEFAULT_KERAS_INFO) + DEFAULT_KERAS_INFO.get_kernel_op_attributes(n.type)[0]) if len(weights_quantizers) > 0: layer.trainable = True return KerasTrainableQuantizationWrapper(layer, weights_quantizers) + + # TODO: need to check if in this case, if there are other weights attributes that are not trainable but are + # quantized, do we need to wrap them as well? 
return layer diff --git a/model_compression_toolkit/qat/keras/quantizer/quantization_builder.py b/model_compression_toolkit/qat/keras/quantizer/quantization_builder.py index 375b334ad..880931e9c 100644 --- a/model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +++ b/model_compression_toolkit/qat/keras/quantizer/quantization_builder.py @@ -42,8 +42,7 @@ def get_activation_quantizer_holder(n: common.BaseNode, A KerasActivationQuantizationHolder layer for the node activation quantization. """ _, activation_quantizers = quantization_builder(n, - qat_config, - DEFAULT_KERAS_INFO) + qat_config) # Holder by definition uses a single quantizer for the activation quantization # thus we make sure this is the only possible case (unless it's a node with no activation @@ -55,7 +54,7 @@ def get_activation_quantizer_holder(n: common.BaseNode, def quantization_builder(n: common.BaseNode, qat_config: QATConfig, - fw_info: FrameworkInfo, + kernel_attr: str = None, ) -> Tuple[Dict[str, BaseKerasQATTrainableQuantizer], List[BaseKerasQATTrainableQuantizer]]: """ Build quantizers for a node according to its quantization configuration. @@ -63,29 +62,32 @@ def quantization_builder(n: common.BaseNode, Args: n: Node to build its QuantizeConfig. qat_config (QATConfig): QAT configuration - fw_info: Framework information (e.g., mapping from layers to their attributes to quantize). + kernel_attr: A potential kernel attribute name to build its trainable quantizer. + Returns: weights_quantizers: A dictionary between a weight's name to its quantizer. activation_quantizers: A list of activations quantization, one for each layer output. """ if len(n.candidates_quantization_cfg) > 1: - wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n) + wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n, kernel_attr) else: wq_cand, aq_cand = None, None weight_quantizers = {} - if n.is_weights_quantization_enabled(): - quant_method = n.final_weights_quantization_cfg.weights_quantization_method + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): + # Only nodes with kernel attribute are trainable during QAT + quant_method = n.final_weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights, qat_config.weight_training_method, quant_method, BaseKerasQATTrainableQuantizer) - attributes = fw_info.get_kernel_op_attributes(n.type) - for attr in attributes: - weight_quantizers.update({attr: quantizer_class(get_trainable_quantizer_weights_config(n, wq_cand), - **qat_config.weight_quantizer_params_override)}) + + weight_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n, + attr_name=kernel_attr, + weights_quantization_candidates=wq_cand), + **qat_config.weight_quantizer_params_override)}) activation_quantizers = [] if n.is_activation_quantization_enabled(): diff --git a/model_compression_toolkit/qat/pytorch/quantization_facade.py b/model_compression_toolkit/qat/pytorch/quantization_facade.py index cc6ee7165..d204a7361 100644 --- a/model_compression_toolkit/qat/pytorch/quantization_facade.py +++ b/model_compression_toolkit/qat/pytorch/quantization_facade.py @@ -62,9 +62,14 @@ def qat_wrapper(n: common.BaseNode, """ if is_qat_applicable(n, DEFAULT_PYTORCH_INFO): - weights_quantizers, _ = quantization_builder(n, qat_config, DEFAULT_PYTORCH_INFO) + # If we are here, then the node has a kernel attribute to quantize and training during QAT + 
weights_quantizers, _ = quantization_builder(n, qat_config, + DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)[0]) if len(weights_quantizers) > 0: return PytorchQuantizationWrapper(module, weights_quantizers) + + # TODO: need to check if in this case, if there are other weights attributes that are not trainable but are + # quantized, do we need to wrap them as well? return module diff --git a/model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py b/model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py index 00ffcaf21..0645512c8 100644 --- a/model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +++ b/model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py @@ -42,8 +42,7 @@ def get_activation_quantizer_holder(n: common.BaseNode, A ActivationQuantizationHolder layer for the node's activation quantization. """ _, activation_quantizers = quantization_builder(n, - qat_config, - DEFAULT_PYTORCH_INFO) + qat_config) # Holder by definition uses a single quantizer for the activation quantization # thus we make sure this is the only possible case (unless it's a node with no activation @@ -55,7 +54,7 @@ def get_activation_quantizer_holder(n: common.BaseNode, def quantization_builder(n: common.BaseNode, qat_config: QATConfig, - fw_info: FrameworkInfo, + kernel_attr: str = None, ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer], List[BasePytorchQATTrainableQuantizer]]: """ @@ -64,28 +63,31 @@ def quantization_builder(n: common.BaseNode, Args: n: Node to build its QuantizeConfig. qat_config (QATConfig): QAT configuration - fw_info: Framework information (e.g., mapping from layers to their attributes to quantize). + kernel_attr: A potential kernel attribute name to build its trainable quantizer. Returns: weights_quantizers: A dictionary between a weight's name to its quantizer. activation_quantizers: A list of activations quantization, one for each layer output.). 
""" + if len(n.candidates_quantization_cfg) > 1: - wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n) + wq_cand, aq_cand = get_trainable_quantizer_quantization_candidates(n, kernel_attr) else: wq_cand, aq_cand = None, None weight_quantizers = {} - if n.is_weights_quantization_enabled(): - quant_method = n.final_weights_quantization_cfg.weights_quantization_method + if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr): + # Only nodes with kernel attribute are trainable during QAT + quant_method = n.final_weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights, qat_config.weight_training_method, quant_method, BasePytorchQATTrainableQuantizer) - attributes = fw_info.get_kernel_op_attributes(n.type) - for attr in attributes: - weight_quantizers.update({attr: quantizer_class(get_trainable_quantizer_weights_config(n, wq_cand), - **qat_config.weight_quantizer_params_override)}) + + weight_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n, + attr_name=kernel_attr, + weights_quantization_candidates=wq_cand), + **qat_config.weight_quantizer_params_override)}) activation_quantizers = [] if n.is_activation_quantization_enabled(): diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py index a7214de43..78c3082c6 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py @@ -17,6 +17,7 @@ from typing import List, Dict, Union, Any from mct_quantizers import QuantizationMethod +from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.logger import Logger @@ -46,11 +47,11 @@ class AttributeQuantizationConfig: Hold the quantization configuration of a weight attribute of a layer. """ def __init__(self, - weights_quantization_method: QuantizationMethod, - weights_n_bits: int, - weights_per_channel_threshold: bool, - enable_weights_quantization: bool, - lut_values_bitwidth: Union[int, None], # If None - set 8 in hptq, o.w use it + weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO, + weights_n_bits: int = FLOAT_BITWIDTH, + weights_per_channel_threshold: bool = False, + enable_weights_quantization: bool = False, + lut_values_bitwidth: Union[int, None] = None, # If None - set 8 in hptq, o.w use it ): """ Initializes an attribute quantization config. diff --git a/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py b/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py index 3430786e8..244f05cf1 100644 --- a/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +++ b/model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py @@ -21,6 +21,7 @@ def get_trainable_quantizer_weights_config( n: BaseNode, + attr_name: str, weights_quantization_candidates: List[TrainableQuantizerCandidateConfig] = None ) -> TrainableQuantizerWeightsConfig: """ @@ -28,6 +29,7 @@ def get_trainable_quantizer_weights_config( Args: n: BaseNode - the node to build a trainable quantizer from. + attr_name: Attribute name to get its weights quantizer configuration. 
weights_quantization_candidates: A list of weights quantizer config candidates. Returns: @@ -36,14 +38,15 @@ def get_trainable_quantizer_weights_config( if n.final_weights_quantization_cfg is None: Logger.error(f'Node must have final_weights_quantization_cfg in order to build quantizer configuration') # pragma: no cover - final_cfg = n.final_weights_quantization_cfg - return TrainableQuantizerWeightsConfig(final_cfg.weights_quantization_method, - final_cfg.weights_n_bits, - final_cfg.weights_quantization_params, - final_cfg.enable_weights_quantization, - final_cfg.weights_channels_axis, - final_cfg.weights_per_channel_threshold, - final_cfg.min_threshold, + final_node_cfg = n.final_weights_quantization_cfg + final_attr_cfg = final_node_cfg.get_attr_config(attr_name) + return TrainableQuantizerWeightsConfig(final_attr_cfg.weights_quantization_method, + final_attr_cfg.weights_n_bits, + final_attr_cfg.weights_quantization_params, + final_attr_cfg.enable_weights_quantization, + final_attr_cfg.weights_channels_axis[0], # Output channel axis + final_attr_cfg.weights_per_channel_threshold, + final_node_cfg.min_threshold, weights_quantization_candidates) @@ -73,41 +76,44 @@ def get_trainable_quantizer_activation_config( activation_quantization_candidates) -def get_trainable_quantizer_quantization_candidates(n: BaseNode): +def get_trainable_quantizer_quantization_candidates(n: BaseNode, attr: str = None): """ Returns quantization configuration candidates for activation and weights trainable quantizer. Checks that the candidates are compatible with trainable quantizer Args: n: BaseNode - the node to build a trainable quantizer from + attr: Weights attribute to get its quantization configuration candidates and trainable quantizer. Returns: weights_quantization_candidates - A list of configuration candidates for weights activation_quantization_candidates - A list of configuration candidates for activation """ - # all candidates must have the same weights quantization method - weights_quantization_methods = set([cfg.weights_quantization_cfg.weights_quantization_method for cfg in n.candidates_quantization_cfg]) - if len(weights_quantization_methods) > 1: - Logger.error(f'Unsupported candidates_quantization_cfg with different weights quantization methods: {weights_quantization_methods}') # pragma: no cover + + if attr is not None: + # all candidates must have the same weights quantization method + weights_quantization_methods = set([cfg.weights_quantization_cfg.get_attr_config(attr).weights_quantization_method + for cfg in n.candidates_quantization_cfg]) + if len(weights_quantization_methods) > 1: + Logger.error(f'Unsupported candidates_quantization_cfg with different weights quantization methods: ' + f'{weights_quantization_methods}') # pragma: no cover # all candidates must have the same activation quantization method - activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method for cfg in n.candidates_quantization_cfg]) + activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method + for cfg in n.candidates_quantization_cfg]) if len(activation_quantization_methods) > 1: - Logger.error(f'Unsupported candidates_quantization_cfg with different activation quantization methods: {activation_quantization_methods}') # pragma: no cover + Logger.error(f'Unsupported candidates_quantization_cfg with different activation quantization methods: ' + f'{activation_quantization_methods}') # pragma: no cover # get unique lists of 
candidates - unique_weights_candidates = n.get_unique_weights_candidates() + unique_weights_candidates = n.get_unique_weights_candidates(attr) unique_activation_candidates = n.get_unique_activation_candidates() - # verify all the combinations of weights_n_bits and activation_n_bits are allowed - if len(n.candidates_quantization_cfg) != len(unique_weights_candidates) * len(unique_activation_candidates): - Logger.error(f'Unsupported candidates_quantization_cfg for a trainable quantizer,' - f'it must contain all the combinations of (weights_n_bits X activations_n_bits)') # pragma: no cover - # generate list of weights quantizer candidates weights_cfg_candidates = [TrainableQuantizerCandidateConfig( - cfg.weights_quantization_cfg.weights_n_bits, - cfg.weights_quantization_cfg.weights_quantization_params) for cfg in unique_weights_candidates] + cfg.weights_quantization_cfg.get_attr_config(attr).weights_n_bits, + cfg.weights_quantization_cfg.get_attr_config(attr).weights_quantization_params) + for cfg in unique_weights_candidates] # generate list of activation quantizer candidates activation_cfg_candidates = [TrainableQuantizerCandidateConfig( diff --git a/tests/common_tests/helpers/prep_graph_for_func_test.py b/tests/common_tests/helpers/prep_graph_for_func_test.py index 13b2f8bbc..35e31135a 100644 --- a/tests/common_tests/helpers/prep_graph_for_func_test.py +++ b/tests/common_tests/helpers/prep_graph_for_func_test.py @@ -85,9 +85,7 @@ def prepare_graph_with_quantization_parameters(in_model, for i in range(10): mi.infer([np.random.randn(*input_shape)]) - calculate_quantization_params(graph, - fw_info=fw_info, - fw_impl=fw_impl) + calculate_quantization_params(graph) return graph diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/lut_quantizer.py b/tests/keras_tests/feature_networks_tests/feature_networks/lut_quantizer.py index d3f4aa6ad..6462f300c 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/lut_quantizer.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/lut_quantizer.py @@ -65,9 +65,12 @@ def get_tpc(self): return generate_keras_tpc(name='lut_quantizer_test', tp_model=tp_model) def get_debug_config(self): - return mct.core.DebugConfig(network_editor=[EditRule(filter=NodeNameFilter(self.node_to_change_name), - action=ChangeCandidatesWeightsQuantizationMethod( - weights_quantization_method=mct.target_platform.QuantizationMethod.POWER_OF_TWO))]) + return mct.core.DebugConfig( + network_editor=[EditRule(filter=NodeNameFilter(self.node_to_change_name), + action=ChangeCandidatesWeightsQuantizationMethod( + weights_quantization_method= + mct.target_platform.QuantizationMethod.POWER_OF_TWO, + attr_name=KERNEL))]) def get_input_shapes(self): return [[self.val_batch_size, 16, 16, self.num_conv_channels]] diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/change_qc_attr_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/change_qc_attr_test.py index e4d260f98..ad90e70b4 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/change_qc_attr_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/change_qc_attr_test.py @@ -20,6 +20,7 @@ from model_compression_toolkit.core.common.network_editors.actions import EditRule, ChangeFinalWeightsQuantConfigAttr, \ ChangeFinalActivationQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr from model_compression_toolkit.core.common.network_editors.node_filters import 
NodeTypeFilter +from model_compression_toolkit.core.keras.constants import KERNEL from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest from tests.keras_tests.utils import get_layers_from_model_by_type diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py index 4ec8c422b..90a3188b6 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/edit_qc_test.py @@ -33,6 +33,7 @@ statistics_correction_runner from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner +from model_compression_toolkit.core.keras.constants import KERNEL from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod @@ -93,9 +94,7 @@ def prepare_graph_for_second_network_editor(in_model, representative_data_gen, c ###################################### # Calculate quantization params ###################################### - calculate_quantization_params(transformed_graph, - fw_info, - fw_impl=fw_impl) + calculate_quantization_params(transformed_graph) if tb_w is not None: tb_w.add_graph(transformed_graph, 'thresholds_selection') @@ -285,10 +284,10 @@ def run_test(self, experimental_exporter=False): for node in filtered_nodes: if node.final_weights_quantization_cfg is not None and hasattr(self.action, 'weights_quantization_method'): - self.unit_test.assertTrue(node.final_weights_quantization_cfg.weights_quantization_method - == self.action.weights_quantization_method) + self.unit_test.assertTrue(node.final_weights_quantization_cfg.get_attr_config(KERNEL) + .weights_quantization_method == self.action.weights_quantization_method) elif node.final_activation_quantization_cfg is not None and hasattr(self.action, - 'activation_quantization_method'): + 'activation_quantization_method'): self.unit_test.assertTrue(node.final_activation_quantization_cfg.activation_quantization_method == self.action.activation_quantization_method) else: @@ -297,8 +296,9 @@ def run_test(self, experimental_exporter=False): self.unit_test.assertTrue(nqc.activation_quantization_cfg.activation_quantization_method == self.action.activation_quantization_method) if hasattr(self.action, 'weights_quantization_method'): - self.unit_test.assertTrue(nqc.weights_quantization_cfg.weights_quantization_method - == self.action.weights_quantization_method) + self.unit_test.assertTrue(nqc.weights_quantization_cfg.get_attr_config(KERNEL) + .weights_quantization_method == + self.action.weights_quantization_method) class ChangeCandidatesActivationQuantizationMethodQCAttrTest(BaseChangeQuantizationMethodQCAttrTest): @@ -314,7 +314,8 @@ class ChangeCandidatesWeightsQuantizationMethodQCAttrTest(BaseChangeQuantization def __init__(self, unit_test): edit_filter = NodeTypeFilter(layers.Conv2D) - action = ChangeCandidatesWeightsQuantizationMethod(weights_quantization_method=QuantizationMethod.UNIFORM) + action = ChangeCandidatesWeightsQuantizationMethod(attr_name=KERNEL, + weights_quantization_method=QuantizationMethod.UNIFORM) prepare_graph_func = prepare_graph_for_first_network_editor 
super().__init__(unit_test, edit_filter=edit_filter, action=action, prepare_graph_func=prepare_graph_func) @@ -332,6 +333,7 @@ class ChangeFinalsWeightsQuantizationMethodQCAttrTest(BaseChangeQuantizationMeth def __init__(self, unit_test): edit_filter = NodeTypeFilter(layers.Conv2D) - action = ChangeFinalWeightsQuantizationMethod(weights_quantization_method=QuantizationMethod.UNIFORM) + action = ChangeFinalWeightsQuantizationMethod(attr_name=KERNEL, + weights_quantization_method=QuantizationMethod.UNIFORM) prepare_graph_func = prepare_graph_for_second_network_editor super().__init__(unit_test, edit_filter=edit_filter, action=action, prepare_graph_func=prepare_graph_func) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py index 22d77295f..516db39e5 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py @@ -23,6 +23,7 @@ NodeTypeFilter from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \ get_weights_quantization_params_fn, get_activation_quantization_params_fn +from model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest @@ -73,11 +74,14 @@ def get_debug_config(self): network_editor = [EditRule(filter=NodeNameScopeFilter(self.scope), action=ChangeCandidatesActivationQuantConfigAttr(activation_n_bits=self.activation_n_bits)), EditRule(filter=NodeNameScopeFilter(self.scope), - action=ChangeCandidatesWeightsQuantConfigAttr(weights_n_bits=self.weights_n_bits)), - EditRule(filter=NodeNameScopeFilter('2'), - action=ChangeCandidatesWeightsQuantConfigAttr(enable_weights_quantization=True)), - EditRule(filter=NodeNameScopeFilter('2') or NodeNameScopeFilter('does_not_exist'), - action=ChangeCandidatesWeightsQuantConfigAttr(enable_weights_quantization=False)) + action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, + weights_n_bits=self.weights_n_bits)), + EditRule(filter=NodeNameScopeFilter('change_2'), + action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, + enable_weights_quantization=True)), + EditRule(filter=NodeNameScopeFilter('change_2') or NodeNameScopeFilter('does_not_exist'), + action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, + enable_weights_quantization=False)) ] return mct.core.DebugConfig(network_editor=network_editor) @@ -87,8 +91,8 @@ def get_input_shapes(self): def create_networks(self): inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) x = layers.Conv2D(self.num_conv_channels, self.kernel, use_bias=False, name='unchanged')(inputs) - x = layers.Conv2D(self.num_conv_channels, self.kernel, use_bias=False, name=self.scope + '_1')(x) - x = layers.Conv2D(self.num_conv_channels, self.kernel, use_bias=False, name=self.scope + '_2')(x) + x = layers.Conv2D(self.num_conv_channels, self.kernel, use_bias=False, name=self.scope + 'change_1')(x) + x = layers.Conv2D(self.num_conv_channels, self.kernel, use_bias=False, name=self.scope + 'change_2')(x) outputs = layers.Conv2D(self.num_conv_channels, 
self.kernel, use_bias=False)(x) model = keras.Model(inputs=inputs, outputs=outputs) @@ -149,7 +153,8 @@ def get_debug_config(self): network_editor = [EditRule(filter=NodeNameFilter(self.node_to_change_name), action=ChangeCandidatesActivationQuantConfigAttr(activation_n_bits=self.activation_n_bits)), EditRule(filter=NodeNameFilter(self.node_to_change_name), - action=ChangeCandidatesWeightsQuantConfigAttr(weights_n_bits=self.weights_n_bits)) + action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, + weights_n_bits=self.weights_n_bits)) ] return mct.core.DebugConfig(network_editor=network_editor) @@ -218,11 +223,13 @@ def get_quantization_config(self): def get_debug_config(self): network_editor = [EditRule(filter=NodeTypeFilter(self.type_to_change), - action=ChangeCandidatesWeightsQuantConfigAttr(weights_n_bits=self.weights_n_bits)), + action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, + weights_n_bits=self.weights_n_bits)), EditRule(filter=NodeTypeFilter(self.type_to_change), action=ChangeCandidatesActivationQuantConfigAttr(activation_n_bits=self.activation_n_bits)), EditRule(filter=NodeTypeFilter(self.type_to_change).__and__(NodeNameFilter(self.node_to_change_name)), - action=ChangeQuantizationParamFunction(weights_quantization_params_fn=self.weights_params_fn())), + action=ChangeQuantizationParamFunction(attr_name=KERNEL, + weights_quantization_params_fn=self.weights_params_fn())), EditRule(filter=NodeTypeFilter(self.type_to_change).__and__(NodeNameFilter(self.node_to_change_name)), action=ChangeQuantizationParamFunction( activation_quantization_params_fn=self.activations_params_fn())), diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/test_kmeans_quantizer.py b/tests/keras_tests/feature_networks_tests/feature_networks/test_kmeans_quantizer.py index 1b34d1b4e..b1ee10feb 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/test_kmeans_quantizer.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/test_kmeans_quantizer.py @@ -93,12 +93,15 @@ def create_networks(self): def get_debug_config(self): return mct.core.DebugConfig(network_editor=[EditRule(filter=NodeNameFilter(self.node_to_change_name), action=ChangeCandidatesWeightsQuantConfigAttr( + attr_name=KERNEL, weights_quantization_method=target_platform.QuantizationMethod.POWER_OF_TWO)), EditRule(filter=NodeNameFilter(self.node_to_change_name), action=ChangeCandidatesWeightsQuantConfigAttr( + attr_name=KERNEL, weights_quantization_fn=power_of_two_quantizer)), EditRule(filter=NodeNameFilter(self.node_to_change_name), action=ChangeCandidatesWeightsQuantConfigAttr( + attr_name=KERNEL, weights_quantization_params_fn=power_of_two_selection_tensor)), ]) diff --git a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py index 717bc3ecd..732a1aa68 100644 --- a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py +++ b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py @@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph +from model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation from 
model_compression_toolkit.core.common.fusion.layer_fusing import fusion @@ -96,7 +97,7 @@ def test_cfg_filter_activation_only_nodes(self): # Filtering nodes; candidates filtered_graph = filter_nodes_candidates(graph) - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() + filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes(DEFAULT_KERAS_INFO) # checking that layers with activation only (input and relu) have filtered configurations list, # that they have a configuration for each of the original bitwidth options @@ -125,7 +126,7 @@ def test_cfg_filter_weights_disabled(self): # Filtering nodes; candidates filtered_graph = filter_nodes_candidates(graph) - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() + filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes(DEFAULT_KERAS_INFO) # checking that layers with weights (conv2d) have filtered activation configurations list # when weights quantization is disabled @@ -154,14 +155,14 @@ def test_cfg_filter_activation_disabled(self): # Filtering nodes; candidates filtered_graph = filter_nodes_candidates(graph) - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() + filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes(DEFAULT_KERAS_INFO) # checking that layers with weights (conv2d) have filtered weights configurations list # when activation quantization is disabled - conv2d_candidates = filtered_configurable_nodes[0].candidates_quantization_cfg - self.assertTrue(len(conv2d_candidates) == 3, - f"Expects 3 Conv layer candidates, number of candidates is {len(conv2d_candidates)}") - self.assertTrue([c.weights_quantization_cfg.weights_n_bits for c in conv2d_candidates] == [8, 4, 2]) + conv2d_kernel_candidates = filtered_configurable_nodes[0].get_all_weights_attr_candidates(KERNEL) + self.assertTrue(len(conv2d_kernel_candidates) == 3, + f"Expects 3 Conv layer kernel candidates, number of candidates is {len(conv2d_kernel_candidates)}") + self.assertTrue([c.weights_n_bits for c in conv2d_kernel_candidates] == [8, 4, 2]) def test_cfg_filter_multiple_candidates_weights_disabled(self): input_shape = (8, 8, 3) @@ -178,7 +179,7 @@ def test_cfg_filter_multiple_candidates_weights_disabled(self): # Filtering nodes; candidates filtered_graph = filter_nodes_candidates(graph) - filtered_graph_nodes = list(filtered_graph.nodes) + filtered_graph_nodes = filtered_graph.get_topo_sorted_nodes() # checking that layers with weights (conv2d) have filtered weights configurations list # when activation quantization is disabled @@ -186,8 +187,8 @@ def test_cfg_filter_multiple_candidates_weights_disabled(self): self.assertTrue(len(conv2d_candidates) == 1, f"Expects 1 Conv layer candidates, number of candidates is {len(conv2d_candidates)}") candidate = conv2d_candidates[0] - self.assertTrue((candidate.weights_quantization_cfg.weights_n_bits, - candidate.activation_quantization_cfg.activation_n_bits) == (FLOAT_BITWIDTH , 8)) + self.assertTrue((candidate.weights_quantization_cfg.get_attr_config(KERNEL).weights_n_bits, + candidate.activation_quantization_cfg.activation_n_bits) == (FLOAT_BITWIDTH, 8)) if __name__ == '__main__': diff --git a/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py b/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py index 405120299..0e33b7853 100644 --- a/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py +++ b/tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py @@ -29,8 +29,9 
@@ ConfigurableActivationQuantizer
 from model_compression_toolkit.core.keras.mixed_precision.configurable_weights_quantizer import \
     ConfigurableWeightsQuantizer
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc
 from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters
+from tests.keras_tests.exporter_tests.tflite_int8.imx500_int8_tp_model import get_op_quantization_configs
+from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras, get_weights_only_mp_tpc_keras
 
 
 def base_model(input_shape):
@@ -43,12 +44,14 @@ def representative_dataset():
     yield [np.random.randn(1, 8, 8, 3).astype(np.float32)]
 
 
-def test_setup():
+def test_setup(get_tpc_fn):
     model = base_model((8, 8, 3))
 
     graph = prepare_graph_with_quantization_parameters(model, KerasImplementation(), DEFAULT_KERAS_INFO,
-                                                       representative_dataset, generate_keras_tpc, input_shape=(1, 8, 8, 3))
+                                                       representative_dataset, get_tpc_fn,
+                                                       input_shape=(1, 8, 8, 3),
+                                                       mixed_precision_enabled=True)
 
     layer = model.layers[1]
     node = graph.get_topo_sorted_nodes()[1]
@@ -59,7 +62,16 @@ def test_setup():
 class TestKerasSetLayerToBitwidth(unittest.TestCase):
 
     def test_set_layer_to_bitwidth_weights(self):
-        layer, node = test_setup()
+        base_config, _, default_config = get_op_quantization_configs()
+        tpc = get_weights_only_mp_tpc_keras(
+            base_config=base_config,
+            default_config=default_config,
+            mp_bitwidth_candidates_list=[(8, 8), (4, 8), (2, 8)],
+            name='set_layer_test_tpc')
+
+        # This test needs a dedicated TPC, so we override the TPC generator function that is passed
+        # to the test preparation helper method
+        layer, node = test_setup(get_tpc_fn=lambda x, y: tpc)
 
         wrapper_layer = \
             KerasTrainableQuantizationWrapper(layer,
@@ -67,7 +79,8 @@ def test_set_layer_to_bitwidth_weights(self):
                               ConfigurableWeightsQuantizer(
                                   node_q_cfg=node.candidates_quantization_cfg,
                                   float_weights=node.get_weights_by_keys(KERNEL),
-                                  max_candidate_idx=node.find_max_candidates_indices()[0])
+                                  max_candidate_idx=node.find_max_candidates_indices()[0],
+                                  kernel_attr=KERNEL)
                           })
 
         for attr, q in wrapper_layer.weights_quantizers.items():
@@ -84,7 +97,16 @@ def test_set_layer_to_bitwidth_weights(self):
             self.assertEqual(q.active_quantization_config_index, 0)
 
     def test_set_layer_to_bitwidth_activation(self):
-        layer, node = test_setup()
+        base_config, _, default_config = get_op_quantization_configs()
+        tpc = get_tpc_with_activation_mp_keras(
+            base_config=base_config,
+            default_config=default_config,
+            mp_bitwidth_candidates_list=[(8, 8), (8, 4), (8, 2)],
+            name='set_layer_test_tpc')
+
+        # This test needs a dedicated TPC, so we override the TPC generator function that is passed
+        # to the test preparation helper method
+        layer, node = test_setup(get_tpc_fn=lambda x, y: tpc)
 
         holder_layer = \
             KerasActivationQuantizationHolder(ConfigurableActivationQuantizer(
diff --git a/tests/keras_tests/function_tests/test_symmetric_threshold_selection_weights.py b/tests/keras_tests/function_tests/test_symmetric_threshold_selection_weights.py
index d6c40f1e2..8530223a8 100644
--- a/tests/keras_tests/function_tests/test_symmetric_threshold_selection_weights.py
+++ b/tests/keras_tests/function_tests/test_symmetric_threshold_selection_weights.py
@@ -21,6 +21,7 @@
 import model_compression_toolkit as mct
 from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod
 from model_compression_toolkit.constants import THRESHOLD
+from 
model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation @@ -102,8 +103,8 @@ def run_test_for_threshold_method(self, threshold_method, per_channel=True): qc=qc, input_shape=(1, 16, 16, 4)) nodes_list = list(graph.nodes) - conv1_threshold = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[THRESHOLD] - conv2_threshold = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[THRESHOLD] + conv1_threshold = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[THRESHOLD] + conv2_threshold = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[THRESHOLD] conv1_threshold_log = np.log2(conv1_threshold) conv2_threshold_log = np.log2(conv2_threshold) self.assertFalse(np.array_equal(conv1_threshold_log, conv1_threshold_log.astype(int)), diff --git a/tests/keras_tests/function_tests/test_uniform_range_selection_weights.py b/tests/keras_tests/function_tests/test_uniform_range_selection_weights.py index 4f7c044bb..2d9cadf56 100644 --- a/tests/keras_tests/function_tests/test_uniform_range_selection_weights.py +++ b/tests/keras_tests/function_tests/test_uniform_range_selection_weights.py @@ -21,6 +21,7 @@ import model_compression_toolkit as mct from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX +from model_compression_toolkit.core.keras.constants import KERNEL from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation @@ -101,10 +102,10 @@ def run_test_for_threshold_method(self, threshold_method, per_channel=True): qc=qc, input_shape=(1, 16, 16, 4)) nodes_list = list(graph.nodes) - conv1_min = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[RANGE_MIN].flatten() - conv2_min = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[RANGE_MIN].flatten() - conv1_max = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[RANGE_MAX].flatten() - conv2_max = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.weights_quantization_params[RANGE_MAX].flatten() + conv1_min = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[RANGE_MIN].flatten() + conv2_min = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[RANGE_MIN].flatten() + conv1_max = nodes_list[0].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[RANGE_MAX].flatten() + conv2_max = nodes_list[1].candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_quantization_params[RANGE_MAX].flatten() for range_min, range_max in 
list(zip(conv1_min, conv1_max)): self.assertTrue(range_min <= 0 <= range_max, diff --git a/tests/keras_tests/layer_tests/test_layers_runner.py b/tests/keras_tests/layer_tests/test_layers_runner.py index 2c02475f7..02edffc6b 100644 --- a/tests/keras_tests/layer_tests/test_layers_runner.py +++ b/tests/keras_tests/layer_tests/test_layers_runner.py @@ -245,12 +245,15 @@ def test_crop_and_resize(self): boxes = tf.random.uniform(shape=(5, 4)) box_indices = tf.random.uniform(shape=(5,), minval=0, maxval=1, dtype=tf.int32) + + # TODO: Exporting layers with constant weights is not supported. Enable exporter once feature is supported. BaseKerasLayerTest(self, [partial(tf.image.crop_and_resize, boxes=boxes, box_indices=box_indices, crop_size=(22, 19)), partial(tf.image.crop_and_resize, boxes=boxes, box_indices=box_indices, crop_size=(21, 24), method='nearest'), partial(tf.image.crop_and_resize, boxes=boxes, box_indices=box_indices, crop_size=(24, 20), - extrapolation_value=0)]).run_test() + extrapolation_value=0)], + experimental_exporter=False).run_test() def test_conv2dtranspose(self): BaseKerasLayerTest(self, diff --git a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py index 26626a6ad..59ea34f98 100644 --- a/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py +++ b/tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py @@ -244,9 +244,8 @@ def dummy_representative_dataset(): def representative_data_gen(): yield [np.random.random(input_shape)] - calculate_quantization_params(graph, - fw_info, - fw_impl=keras_impl) + calculate_quantization_params(graph) + keras_impl.get_sensitivity_evaluator(graph, core_config.mixed_precision_config, representative_data_gen, diff --git a/tests/pytorch_tests/function_tests/set_layer_to_bitwidth_test.py b/tests/pytorch_tests/function_tests/set_layer_to_bitwidth_test.py index ce421f758..cbafbadcb 100644 --- a/tests/pytorch_tests/function_tests/set_layer_to_bitwidth_test.py +++ b/tests/pytorch_tests/function_tests/set_layer_to_bitwidth_test.py @@ -25,9 +25,12 @@ from model_compression_toolkit.core.pytorch.mixed_precision.configurable_weights_quantizer import \ ConfigurableWeightsQuantizer from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs +from tests.common_tests.helpers.generate_test_tp_model import generate_mixed_precision_test_tp_model from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest +from tests.pytorch_tests.tpc_pytorch import get_pytorch_test_tpc_dict + class base_model(torch.nn.Module): @@ -41,10 +44,10 @@ def forward(self, inp): return x -def test_setup(representative_data_gen): +def test_setup(representative_data_gen, get_tpc_fn): model = base_model() graph = prepare_graph_with_quantization_parameters(model, PytorchImplementation(), DEFAULT_PYTORCH_INFO, - representative_data_gen, generate_pytorch_tpc, + representative_data_gen, get_tpc_fn, input_shape=(1, 3, 8, 8), mixed_precision_enabled=True) @@ -68,13 +71,26 @@ def representative_data_gen(self, n_iters=1): yield self.generate_inputs(input_shapes) def run_test(self, seed=0, **kwargs): - node, 
layer = test_setup(self.representative_data_gen)
+        base_config, _, default_config = get_op_quantization_configs()
+        tpc = get_pytorch_test_tpc_dict(
+            tp_model=generate_mixed_precision_test_tp_model(
+                base_cfg=base_config,
+                default_config=default_config,
+                mp_bitwidth_candidates_list=[(8, 8), (4, 8), (2, 8)]),
+            test_name='set_layer_bit_tests',
+            ftp_name='set_layer_bit_tests')['set_layer_bit_tests']
+
+        # This test needs a dedicated TPC, so we override the TPC generator function that is passed
+        # to the test preparation helper method
+        node, layer = test_setup(self.representative_data_gen, get_tpc_fn=lambda x, y: tpc)
+
         wrapper_layer = PytorchQuantizationWrapper(layer,
                             weights_quantizers={KERNEL: ConfigurableWeightsQuantizer(
                                 node_q_cfg=node.candidates_quantization_cfg,
                                 float_weights=node.get_weights_by_keys(KERNEL),
-                                max_candidate_idx=node.find_max_candidates_indices()[0])
+                                max_candidate_idx=node.find_max_candidates_indices()[0],
+                                kernel_attr=KERNEL)
                             })
 
         for attr, q in wrapper_layer.weights_quantizers.items():
@@ -105,7 +121,19 @@ def representative_data_gen(self, n_iters=1):
             yield self.generate_inputs(input_shapes)
 
     def run_test(self, seed=0, **kwargs):
-        node, layer = test_setup(self.representative_data_gen)
+        base_config, _, default_config = get_op_quantization_configs()
+        tpc = get_pytorch_test_tpc_dict(
+            tp_model=generate_mixed_precision_test_tp_model(
+                base_cfg=base_config,
+                default_config=default_config,
+                mp_bitwidth_candidates_list=[(8, 8), (8, 4), (8, 2)]),
+            test_name='set_layer_bit_tests',
+            ftp_name='set_layer_bit_tests')['set_layer_bit_tests']
+
+        # This test needs a dedicated TPC, so we override the TPC generator function that is passed
+        # to the test preparation helper method
+        node, layer = test_setup(self.representative_data_gen, get_tpc_fn=lambda x, y: tpc)
+
         holder_layer = \
             PytorchActivationQuantizationHolder(ConfigurableActivationQuantizer(
                 node_q_cfg=node.candidates_quantization_cfg,
diff --git a/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py b/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py
index 012eb26aa..e077fc88d 100644
--- a/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py
+++ b/tests/pytorch_tests/layer_tests/base_pytorch_layer_test.py
@@ -252,6 +252,7 @@ def __compare_8bits_quantization_mode(self, float_model, quantized_model, quanti
             float_layer_name = str(node.target).split('.')[0]
             float_weights = get_layer_weights(getattr(float_model, float_layer_name))
             for k, v in quantized_weights.items():
+                # TODO: remove use of kernel op dict
                 if k in fw_info.kernel_ops_attributes_mapping.get(type(op)):
                     float_weight = float_weights.get(k)
                     self.unit_test.assertFalse(float_weight is None)
diff --git a/tests/pytorch_tests/model_tests/feature_models/lut_quantizer_test.py b/tests/pytorch_tests/model_tests/feature_models/lut_quantizer_test.py
index 05850102f..86e64b3e0 100644
--- a/tests/pytorch_tests/model_tests/feature_models/lut_quantizer_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/lut_quantizer_test.py
@@ -18,6 +18,7 @@
 from model_compression_toolkit.core.common.network_editors.node_filters import NodeNameFilter
 from model_compression_toolkit.core.common.network_editors.actions import EditRule, \
     ChangeCandidatesWeightsQuantizationMethod
+from model_compression_toolkit.core.pytorch.constants import KERNEL
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
 from tests.common_tests.helpers.generate_test_tp_model import 
generate_test_tp_model from tests.pytorch_tests.tpc_pytorch import get_pytorch_test_tpc_dict @@ -96,6 +97,7 @@ def get_tpc(self): def get_core_configs(self): network_editor = [EditRule(filter=NodeNameFilter(self.node_to_change_name), action=ChangeCandidatesWeightsQuantizationMethod( + attr_name=KERNEL, weights_quantization_method=self.quant_method))] return {'lut_quantizer_test': mct.core.CoreConfig(quantization_config=mct.core.QuantizationConfig( mct.core.QuantizationErrorMethod.MSE, mct.core.QuantizationErrorMethod.MSE),