Skip to content

Commit

Permalink
Support multiple weights attributes quantization in core (#964)
Browse files Browse the repository at this point in the history
Adding support for quantization of multiple weights attributes from the MCT core side.

There are three main parts to this change:

* Modifying the node candidate quantization configurations list to include per-attribute candidates: a new class WeightsAttrQuantizationConfig was added and a config of this type is created for each config that appears in the attribute's TPC configs mapping.
* Enabling per-attribute operations.
* Restricting several features for kernel-only: Mixed precision, GPTQ, QAT, and pruning currently continue to work only on the kernel attribute. Several features, like SNC, bias correction, etc. are also relevant only for nodes that have a quantized kernel.

Currently, only kernel quantization is enabled by default.
  • Loading branch information
ofirgo committed Mar 6, 2024
1 parent 36f6262 commit 2e42d5c
Show file tree
Hide file tree
Showing 81 changed files with 1,505 additions and 841 deletions.
16 changes: 16 additions & 0 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,19 @@ def sensitivity_eval_inference(self,
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover

def get_inferable_quantizers(self, node: 'BaseNode'):
    """
    Returns sets of framework-compatible weights and activation quantizers for the given node.

    Args:
        node: Node to get quantizers for.

    Returns:
        weight_quantizers: A dictionary between a weight's name to its quantizer.
        activation_quantizers: A list of activation quantizers, one for each layer output.

    Raises:
        NotImplementedError: Always — subclasses must override this framework hook.
    """
    # NotImplementedError is the correct abstract-method exception; the original
    # `raise NotImplemented(...)` would raise a TypeError instead, because the
    # NotImplemented singleton is not callable.
    raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s get_inferable_quantizers method.')  # pragma: no cover
48 changes: 35 additions & 13 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,52 +529,61 @@ def get_float_memory(self) -> float:
return memory

def get_configurable_sorted_nodes_names(self,
                                        fw_info: 'FrameworkInfo',
                                        include_reused_nodes: bool = False) -> List[str]:
    """
    Get a list of names of nodes that can be configured (namely, have one or
    more weight quantization candidates). The names are sorted according to the
    topological order of the graph.

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.
        include_reused_nodes: Whether or not to include reused nodes (False by default).

    Returns:
        List of configurable nodes' names, sorted topologically.
    """
    # The scraped diff fused the pre-change call (without fw_info) with the
    # post-change one; this is the coherent post-change version.
    sorted_names = [n.name for n in self.get_configurable_sorted_nodes(fw_info=fw_info,
                                                                       include_reused_nodes=include_reused_nodes)]
    return sorted_names

def get_weights_configurable_nodes(self,
                                   fw_info: 'FrameworkInfo',
                                   include_reused_nodes: bool = False) -> List['BaseNode']:
    """
    Get a list of nodes whose weights can be configured (namely, have one or
    more weight quantization candidates and their weights should be quantized).

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.
        include_reused_nodes: Whether to include reused nodes (False by default).

    Returns:
        A list of nodes whose weights can be configured (namely, have one or more weight qc candidates).
    """
    # Configurability is only relevant for kernel attribute quantization.
    potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]

    def _is_configurable(n):
        # Hoist the kernel-attribute lookup: the original evaluated
        # fw_info.get_kernel_op_attributes(n.type)[0] twice per node.
        kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
        return (n.is_weights_quantization_enabled(kernel_attr)
                and not n.is_all_weights_candidates_equal(kernel_attr)
                and (not n.reuse or include_reused_nodes))

    return [n for n in potential_conf_nodes if _is_configurable(n)]

def get_sorted_weights_configurable_nodes(self,
                                          fw_info: 'FrameworkInfo',
                                          include_reused_nodes: bool = False) -> List['BaseNode']:
    """
    Get a list of sorted nodes whose weights can be configured (namely, have one
    or more weight quantization candidates and their weights should be quantized).

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.
        include_reused_nodes: Whether to include reused nodes (False by default).

    Returns:
        A list of nodes whose weights can be configured (namely, have one or more
        weight qc candidates), sorted topologically.
    """
    # The scraped diff fused the old call (without fw_info) with the new one;
    # this is the coherent post-change version.
    return self._sort_nodes_in_list(self.get_weights_configurable_nodes(fw_info, include_reused_nodes))

def get_activation_configurable_nodes(self) -> List[BaseNode]:
"""
Expand All @@ -599,20 +608,22 @@ def get_sorted_activation_configurable_nodes(self) -> List[BaseNode]:
return self._sort_nodes_in_list(self.get_activation_configurable_nodes())

def get_configurable_sorted_nodes(self,
fw_info: FrameworkInfo,
include_reused_nodes: bool = False) -> List[BaseNode]:
"""
Get a list of nodes that can be configured (namely, has one or
more qc candidate and their weights or activations should be quantized).
The nodes are sorted according to the topological order of the graph.
Args:
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
include_reused_nodes: Whether or not to include reused nodes (False by default).
Returns:
A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology.
"""
weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes)
weights_configurable_nodes = self.get_weights_configurable_nodes(fw_info, include_reused_nodes)
activation_configurable_nodes = self.get_activation_configurable_nodes()

# combine and remove duplications
Expand All @@ -637,51 +648,62 @@ def _sort_nodes_in_list(self, nodes_list: List[BaseNode]) -> List[BaseNode]:
sorted_configurable_nodes.append(n)
return sorted_configurable_nodes

def get_min_candidates_config(self, fw_info: 'FrameworkInfo') -> List[int]:
    """
    Builds a minimal configuration.
    Note: we assume that a minimal configuration exists, i.e., each configurable node
    has exactly one candidate with minimal n_bits (in both weight and activation if
    both are quantized, or in the relevant one if only one of them is quantized).

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns: A list with one candidate index per configurable node.
    """
    conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
    min_cfg_candidates = [n.find_min_candidates_indices() for n in conf_sorted_nodes]  # list of lists of indices

    # Dropped the stray `f` prefix: the message has no placeholders.
    assert all(len(lst) == 1 for lst in min_cfg_candidates), \
        "A minimal config candidate must be defined, but some node have multiple potential minimal candidates"

    return [lst[0] for lst in min_cfg_candidates]

def get_max_candidates_config(self, fw_info: 'FrameworkInfo') -> List[int]:
    """
    Builds a maximal configuration.
    Note: we assume that a maximal configuration exists, i.e., each configurable node
    has exactly one candidate with maximal n_bits (in both weight and activation if
    both are quantized, or in the relevant one if only one of them is quantized).

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns: A list with one candidate index per configurable node.
    """
    conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
    max_cfg_candidates = [n.find_max_candidates_indices() for n in conf_sorted_nodes]  # list of lists of indices

    # Dropped the stray `f` prefix: the message has no placeholders.
    assert all(len(lst) == 1 for lst in max_cfg_candidates), \
        "A maximal config candidate must be defined, but some node have multiple potential maximal candidates"

    return [lst[0] for lst in max_cfg_candidates]

def get_final_weights_config(self, fw_info: 'FrameworkInfo') -> List[Tuple['BaseNode', int]]:
    """
    Gets the final number of bits for quantization of each weights-configurable layer.

    Args:
        fw_info: FrameworkInfo object with information about the specific framework's model.

    Returns: A list of pairs of (node, node's kernel weights quantization bitwidth).
    """
    sorted_conf_weights = self.get_sorted_weights_configurable_nodes(fw_info)
    # A weights-configurable node by definition has a kernel op, so the kernel
    # attribute lookup is safe. Use the fw_info argument consistently (the diff
    # mixed `fw_info` and `self.fw_info` in the same method).
    return [(n, n.final_weights_quantization_cfg.get_attr_config(fw_info.get_kernel_op_attributes(n.type)[0]).weights_n_bits)
            for n in sorted_conf_weights]

def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
"""
Expand Down
Loading

0 comments on commit 2e42d5c

Please sign in to comment.