Support 16bit activations (#1138)
Refactor the TPC so that each operator specifies its supported input activation bit-widths.
elad-c authored Aug 1, 2024
1 parent f68485b commit 79f5098
Showing 39 changed files with 1,342 additions and 506 deletions.
3 changes: 2 additions & 1 deletion model_compression_toolkit/constants.py
@@ -69,7 +69,8 @@
# that are shared among different candidates:
WEIGHTS_NBITS_ATTRIBUTE = 'weights_n_bits'
CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias'
-ACTIVATION_NBITS_ATTRIBUTE = 'activation_n_bits'
+ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits'
+SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE = 'supported_input_activation_n_bits'

# Quantization Parameters Iterative Search Defaults:
SYMMETRIC_TENSOR_N_ITER = 40
55 changes: 50 additions & 5 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -19,11 +19,11 @@
import numpy as np

from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
-ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
+ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
-TargetPlatformCapabilities, LayerFilterParams
+TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig


class BaseNode:
@@ -297,7 +297,6 @@ def get_memory_bytes(self, fw_info) -> float:

return memory
def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
"""
In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration.
@@ -343,7 +342,7 @@ def get_unified_activation_candidates_dict(self) -> Dict[str, Any]:
Returns: A dictionary containing information from node's activation quantization configuration candidates.
"""
-shared_attributes = [ACTIVATION_NBITS_ATTRIBUTE]
+shared_attributes = [ACTIVATION_N_BITS_ATTRIBUTE]
attr = dict()
if self.is_activation_quantization_enabled():
attr = copy.deepcopy(self.candidates_quantization_cfg[0].activation_quantization_cfg.__dict__)
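For intuition, the "unified" candidates dict gathers the attributes that vary across a node's mixed-precision candidates (here, activation_n_bits) into a single record. A minimal standalone sketch, with a hypothetical helper that is not the toolkit's actual implementation:

import copy
from typing import Any, Dict, List

ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits'

def unify_candidates(candidates: List[Dict[str, Any]],
                     shared_attributes: List[str]) -> Dict[str, Any]:
    # Start from the first candidate's config, then collect the values that
    # vary across candidates into lists under the shared attribute keys.
    unified = copy.deepcopy(candidates[0])
    for attr in shared_attributes:
        unified[attr] = [c[attr] for c in candidates]
    return unified

# Two activation candidates (8-bit and 16-bit) collapse into one record:
print(unify_candidates([{'activation_n_bits': 8, 'signed': True},
                        {'activation_n_bits': 16, 'signed': True}],
                       [ACTIVATION_N_BITS_ATTRIBUTE]))
# -> {'activation_n_bits': [8, 16], 'signed': True}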
@@ -539,7 +538,7 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
to the mappings from layers/LayerFilterParams to the OperatorsSet in the TargetPlatformModel.
Args:
-tpc: TPC to extract the QuantizationConfigOptions for the node
+tpc: TPC to extract the QuantizationConfigOptions for the node.
Returns:
QuantizationConfigOptions of the node.
@@ -559,6 +558,52 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions:
return matching_qcos[0]
return tpc.tp_model.default_qco

def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
next_nodes: List, node_qc_options: QuantizationConfigOptions
) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
"""
Filter quantization config options that don't match the graph.
A node may have several quantization config options with different 'activation_n_bits' values, and
the next nodes in the graph may support different bit-widths as input activations. This function
filters out quantization configs that don't comply with these constraints.
Args:
tpc: TPC to extract the QuantizationConfigOptions for the next nodes.
next_nodes: Output nodes of current node.
node_qc_options: Node's QuantizationConfigOptions.
Returns:
The node's base config and its list of quantization config options, filtered to bit-widths the next nodes accept as input.
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
if _option.activation_n_bits <= next_nodes_supported_input_bitwidth]
if len(_node_qc_options) == 0:
Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
for qc_opt in next_nodes_qc_options]):
# base_config activation bits don't match the next nodes' supported input bit-widths -> replace with
# a qco from quantization_config_list with the maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
Logger.warning(f"Node {self} base quantization config changed to match Graph and TPC configuration.\nCause: {self} -> {next_nodes}.")
else:
Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover

return _base_config, _node_qc_options

def is_match_type(self, _type: Type) -> bool:
"""
Check if input type matches the node type, either in instance type or in type name.
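Condensed to its arithmetic, the new filter_node_qco_by_graph logic does roughly the following. This is a standalone sketch with hypothetical names; the real method operates on QuantizationConfigOptions objects rather than bare integers:

from typing import List, Tuple

def filter_options_by_successors(option_bits: List[int],
                                 base_bits: int,
                                 successors_max_input_bits: List[int]) -> Tuple[int, List[int]]:
    # A node may only emit an activation bit-width that every successor can
    # accept as input, so the strictest successor sets the limit.
    limit = min(successors_max_input_bits)
    kept = [b for b in option_bits if b <= limit]
    if not kept:
        raise ValueError("Graph doesn't match TPC bit configurations.")
    # If the base config exceeds the limit, fall back to the widest surviving option.
    if base_bits > limit:
        base_bits = max(kept)
    return base_bits, kept

# A 16-bit base config is demoted to 8 bits when a successor accepts at most 8:
print(filter_options_by_successors([2, 4, 8, 16], base_bits=16,
                                   successors_max_input_bits=[8, 16]))
# -> (8, [2, 4, 8])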
@@ -195,12 +195,12 @@ def compute_total_bops(graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkI


def requires_mixed_precision(in_model: Any,
                             target_resource_utilization: ResourceUtilization,
                             representative_data_gen: Callable,
                             core_config: CoreConfig,
                             tpc: TargetPlatformCapabilities,
                             fw_info: FrameworkInfo,
                             fw_impl: FrameworkImplementation) -> bool:
"""
The function checks whether the model requires mixed precision to meet the requested target resource utilization.
This is determined by whether the target memory usage of the weights is less than the available memory,
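The docstring's decision rule reduces to a comparison between the model's footprint at maximal (single) precision and the requested targets. A hedged sketch with hypothetical names, assuming the byte counts are computed elsewhere:

def needs_mixed_precision(weights_bytes_at_max_precision: float,
                          activations_bytes_at_max_precision: float,
                          target_weights_bytes: float,
                          target_activations_bytes: float) -> bool:
    # If the model already meets the targets at maximal precision, a
    # mixed-precision search has nothing to trade off.
    return (weights_bytes_at_max_precision > target_weights_bytes or
            activations_bytes_at_max_precision > target_activations_bytes)

# 10 MB of weights at maximal precision against an 8 MB budget -> mixed precision:
print(needs_mixed_precision(10e6, 2e6, 8e6, 4e6))  # True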
@@ -96,6 +96,7 @@ def __init__(self,
self.activation_n_bits = op_cfg.activation_n_bits
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
self.enable_activation_quantization = op_cfg.enable_activation_quantization
+self.is_signed = op_cfg.is_signed
self.activation_channel_equalization = qc.activation_channel_equalization
self.input_scaling = qc.input_scaling
self.min_threshold = qc.min_threshold
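The added line plumbs signedness from the TPC's per-operator config down to the node's activation config. A minimal sketch of that flow, using a hypothetical subset of the real classes:

class OpQuantizationConfig:
    # Hypothetical subset: only the fields relevant to this commit.
    def __init__(self, activation_n_bits: int, is_signed: bool = None):
        self.activation_n_bits = activation_n_bits
        self.is_signed = is_signed  # None means 'infer from statistics'

class NodeActivationQuantizationConfig:
    def __init__(self, op_cfg: OpQuantizationConfig):
        self.activation_n_bits = op_cfg.activation_n_bits
        # New in this commit: the TPC op config can dictate signedness.
        self.is_signed = op_cfg.is_signed

cfg = NodeActivationQuantizationConfig(OpQuantizationConfig(16, is_signed=True))
print(cfg.is_signed)  # True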
@@ -19,7 +19,7 @@

import model_compression_toolkit.core.common.quantization.quantization_config as qc
from model_compression_toolkit.constants import LUT_VALUES, MIN_THRESHOLD, SCALE_PER_CHANNEL, \
-LUT_VALUES_BITWIDTH, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
+LUT_VALUES_BITWIDTH, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import \
max_power_of_two, int_quantization_with_threshold
@@ -110,7 +110,8 @@ def lut_kmeans_histogram(bins: np.ndarray,
constrained: bool = True,
n_iter: int = 20,
min_threshold: float = MIN_THRESHOLD,
-quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> Dict:
+quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
+is_signed: bool = None) -> Dict:
"""
Finds quantization cluster points for non-uniform activation quantization.
The quantizer first finds the closest power-of-two number to the max value of the given histogram,
@@ -129,6 +130,7 @@
n_iter: Number of iterations to search for the threshold (not used for this method).
min_threshold: Minimal threshold to use if threshold is too small.
quant_error_method: an error function to optimize the parameters' selection accordingly (not used for this method).
+is_signed: Whether the quantization is signed or not. If None, compute the SIGNED value.
Returns:
A dictionary containing the cluster assignments according to the k-means algorithm and
@@ -148,9 +150,9 @@
tensor_max = np.max(bins_with_values)
threshold = max_power_of_two(tensor_max, min_threshold)

-signed = np.any(bins[:-1][counts != 0] < 0) # Whether histogram contains negative values or not.
+signed = np.any(bins[:-1][counts != 0] < 0) if is_signed is None else is_signed # Whether histogram contains negative values or not.
tensor_for_kmeans = int_quantization_with_threshold(data=bins, threshold=threshold, n_bits=LUT_VALUES_BITWIDTH, signed=signed)
kmeans.fit(tensor_for_kmeans.reshape(-1, 1), sample_weight=np.insert(counts, 0, 0))

return {LUT_VALUES: np.float32(np.round(kmeans.cluster_centers_)),
-THRESHOLD: threshold}
+THRESHOLD: threshold, SIGNED: signed}
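The is_signed handling added above, isolated into a runnable sketch: an explicit value wins; otherwise signedness is inferred from whether any populated histogram bin lies below zero:

import numpy as np

def resolve_signed(bins: np.ndarray, counts: np.ndarray, is_signed: bool = None) -> bool:
    # bins holds n+1 edges and counts holds n bin counts; only populated bins
    # (counts != 0) participate in the inference.
    return bool(np.any(bins[:-1][counts != 0] < 0)) if is_signed is None else is_signed

bins, counts = np.array([-1.0, 0.0, 1.0, 2.0]), np.array([0, 5, 3])
print(resolve_signed(bins, counts))        # False: the negative bin is empty
print(resolve_signed(bins, counts, True))  # True: explicitly forced by the caller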
@@ -16,7 +16,7 @@
from typing import Union, Tuple, Dict

import model_compression_toolkit.core.common.quantization.quantization_config as qc
-from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
+from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES, SIGNED
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_search import \
qparams_selection_tensor_search, qparams_selection_histogram_search
@@ -105,7 +105,8 @@ def power_of_two_selection_histogram(bins: np.ndarray,
constrained: bool = True,
n_iter: int = 20,
min_threshold: float = MIN_THRESHOLD,
-quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE) -> dict:
+quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
+is_signed: bool = None) -> Dict:
"""
Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize a histogram.
A different search is applied, depending on the value of the selected QuantizationErrorMethod.
@@ -121,24 +122,28 @@
n_iter: Number of iterations to search for the threshold (not used for this method).
min_threshold: Minimal threshold to use if threshold is too small (used only for kl threshold selection).
quant_error_method: an error function to optimize the parameters' selection accordingly.
+is_signed: Whether the quantization is signed or not. If None, compute the SIGNED value.
Returns:
Power of two threshold to quantize the histogram in a power of 2 manner.
"""
if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
tensor_max = np.max(np.abs(bins)[1:][counts > 0])
threshold = max_power_of_two(tensor_max, min_threshold)
+# Resolve is_signed in case it is None.
+signed = (bins < 0).any() if is_signed is None else is_signed
else:
error_function = get_threshold_selection_histogram_error_function(QuantizationMethod.POWER_OF_TWO,
quant_error_method, p)
-threshold = qparams_selection_histogram_search(error_function,
-bins,
-counts,
-n_bits,
-constrained=constrained,
-n_iter=n_iter,
-min_threshold=min_threshold)
-return {THRESHOLD: threshold}
+threshold, signed = qparams_selection_histogram_search(error_function,
+bins,
+counts,
+n_bits,
+constrained=constrained,
+n_iter=n_iter,
+min_threshold=min_threshold,
+is_signed=is_signed)
+return {THRESHOLD: threshold, SIGNED: signed}


def power_of_two_no_clipping_selection_min_max(bins: np.ndarray,
Expand All @@ -151,7 +156,8 @@ def power_of_two_no_clipping_selection_min_max(bins: np.ndarray,
n_iter: int = 20,
min_threshold: float = MIN_THRESHOLD,
quant_error_method: qc.QuantizationErrorMethod =
-qc.QuantizationErrorMethod.NOCLIPPING) -> dict:
+qc.QuantizationErrorMethod.NOCLIPPING,
+is_signed: bool = None) -> Dict:
"""
Gets a threshold between min and max numbers.
If computed threshold is less than min_threshold, min_threshold is returned.
@@ -168,4 +174,5 @@
constrained,
n_iter,
min_threshold=min_threshold,
-quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING)
+quant_error_method=qc.QuantizationErrorMethod.NOCLIPPING,
+is_signed=is_signed)
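For the NOCLIPPING branch, the whole computation fits in a few lines. A standalone sketch: max_power_of_two is reimplemented here under the assumption that it returns the smallest power of two covering the tensor max, clamped below by min_threshold, and the dictionary key names are illustrative:

import numpy as np

def pot_no_clipping_threshold(bins: np.ndarray, counts: np.ndarray,
                              min_threshold: float = 2.0 ** -16,
                              is_signed: bool = None) -> dict:
    # The largest magnitude among populated bins drives the threshold.
    tensor_max = np.max(np.abs(bins)[1:][counts > 0])
    # Assumed behavior of max_power_of_two (see above): smallest 2**k >= max,
    # never below min_threshold.
    threshold = max(min_threshold, 2.0 ** np.ceil(np.log2(max(tensor_max, min_threshold))))
    signed = bool((bins < 0).any()) if is_signed is None else is_signed
    return {'threshold': threshold, 'is_signed': signed}

bins, counts = np.array([-3.0, 0.0, 3.0, 6.0]), np.array([2, 1, 4])
print(pot_no_clipping_threshold(bins, counts))  # {'threshold': 8.0, 'is_signed': True}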
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
-from typing import Dict
+from typing import Dict, Union

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
@@ -25,7 +25,7 @@

def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
nodes_prior_info: NodePriorInfo,
-out_stats_container: BaseStatsCollector) -> Dict[str, float]:
+out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
"""
Compute the activations params for a given node in a graph according to a params function.
@@ -49,7 +49,9 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
bins_counts)
min_value, max_value = out_stats_container.get_min_max_values()

-if nodes_prior_info.is_output_bounded():
+if activation_quant_cfg.is_signed is not None:
+signed = activation_quant_cfg.is_signed
+elif nodes_prior_info.is_output_bounded():
signed = min_value < 0
else:
signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
@@ -65,14 +67,12 @@ def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConf
activation_quant_cfg.activation_quantization_params_fn = \
quantization_params_generation.uniform_no_clipping_selection_min_max

-activation_params = activation_quant_cfg.activation_quantization_params_fn(bins_values,
-bins_counts,
-activation_quant_cfg.l_p_value,
-activation_quant_cfg.activation_n_bits,
-min_value,
-max_value,
-min_threshold=activation_quant_cfg.min_threshold,
-quant_error_method=activation_quant_cfg.activation_error_method)
-activation_params.update({SIGNED: signed})
-
-return activation_params
+return activation_quant_cfg.activation_quantization_params_fn(bins_values,
+bins_counts,
+activation_quant_cfg.l_p_value,
+activation_quant_cfg.activation_n_bits,
+min_value,
+max_value,
+min_threshold=activation_quant_cfg.min_threshold,
+quant_error_method=activation_quant_cfg.activation_error_method,
+is_signed=signed)
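The resulting precedence for activation signedness, as a standalone sketch: an explicit setting on the quantization config wins, then a known-bounded output range, then the observed histogram:

import numpy as np
from typing import Optional

def resolve_activation_signedness(cfg_is_signed: Optional[bool], output_bounded: bool,
                                  min_value: float,
                                  bins: np.ndarray, counts: np.ndarray) -> bool:
    if cfg_is_signed is not None:          # 1. TPC/config dictates signedness
        return cfg_is_signed
    if output_bounded:                     # 2. bounded output: sign of the known min
        return min_value < 0
    return bool(np.any(bins[:-1][counts > 0] < 0))  # 3. fall back to the histogram

# A ReLU-like node with a bounded, non-negative output range is unsigned:
print(resolve_activation_signedness(None, True, 0.0,
                                    np.array([0.0, 1.0, 2.0]), np.array([4, 2])))  # False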