diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index 03969f460..4f82e92e3 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -36,7 +36,7 @@ def __init__(self, framework_attr: Dict[str, Any], input_shape: Tuple[Any], output_shape: Tuple[Any], - weights: Dict[str, np.ndarray], + weights: Dict[Union[str, int], np.ndarray], layer_class: type, reuse: bool = False, reuse_group: str = None, diff --git a/model_compression_toolkit/core/common/graph/functional_node.py b/model_compression_toolkit/core/common/graph/functional_node.py index 3c2ba83bf..8d494f051 100644 --- a/model_compression_toolkit/core/common/graph/functional_node.py +++ b/model_compression_toolkit/core/common/graph/functional_node.py @@ -59,7 +59,7 @@ def __init__(self, has_activation=has_activation) self.op_call_kwargs = op_call_kwargs - self.op_call_args = op_call_args + self.op_call_args = list(op_call_args) self.functional_op = functional_op self.inputs_as_list = inputs_as_list self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs diff --git a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py index 7f6c80a35..4620309fb 100644 --- a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +++ b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from copy import copy import tensorflow as tf from keras.models import Model @@ -19,6 +20,7 @@ from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder from model_compression_toolkit.core.common.user_info import UserInformation +from model_compression_toolkit.logger import Logger if version.parse(tf.__version__) >= version.parse("2.13"): from keras import Input @@ -271,15 +273,38 @@ def _run_operation(self, out_tensors_of_n_float) else: input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists + if isinstance(n, FunctionalNode): + op_call_kwargs = {} if n.op_call_kwargs is None else copy(n.op_call_kwargs) if not isinstance(op_func, KerasQuantizationWrapper): # The KerasQuantizationWrapper will insert the quantized positional weights internally. - input_tensors = n.insert_positional_weights_to_input_list(input_tensors) + if isinstance(n, FunctionalNode): + if n.tensor_input_allocs is not None: + if n.inputs_as_list: + input_tensors = n.insert_positional_weights_to_input_list(input_tensors) + else: + # If there were any const attributes in the layer's inputs, we retrieve them as kwargs + # for the operator call. + for pos, k in enumerate(n.tensor_input_allocs): + if k not in op_call_kwargs: # op_call_kwargs is initialized because we are under FunctionalNode + # If the argument is saved in tensor_input_allocs but does not exist in the node kwargs + # then it is expected to be either an input tensor or a positional weight of the node.
+ arg = n.weights.get(pos) + if arg is None: + if len(input_tensors) == 0: + Logger.critical(f"Couldn't find a weight or input tensor matching operator's " + f"argument name '{k}' in location {pos} for node {n.name}.") + arg = input_tensors.pop(0) + op_call_kwargs.update({k: arg}) + else: + # If the operator is not a functional node then positional weights should be inserted + # into the inputs list. + input_tensors = n.insert_positional_weights_to_input_list(input_tensors) # Build a functional node using its args if isinstance(n, FunctionalNode): if n.inputs_as_list: # If the first argument should be a list of tensors: - out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **n.op_call_kwargs) + out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs) else: # If the input tensors should not be a list but iterated: - out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs) + out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **op_call_kwargs) else: # If operator expects a single input tensor, it cannot be a list as it should # have a dtype field. diff --git a/model_compression_toolkit/core/keras/reader/node_builder.py b/model_compression_toolkit/core/keras/reader/node_builder.py index 213b9f32f..27612c7cf 100644 --- a/model_compression_toolkit/core/keras/reader/node_builder.py +++ b/model_compression_toolkit/core/keras/reader/node_builder.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import Any, List, Dict +from copy import copy + +from typing import Any, List, Dict, Union, Tuple import tensorflow as tf from tensorflow.python.util import tf_inspect @@ -41,7 +43,7 @@ REUSED_IDENTIFIER = '_reused_' -is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float)) +is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, tuple, list)) is_tensor = lambda x: isinstance(x, KerasTensor) @@ -62,35 +64,139 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]: Positional weights are saved according to their index in the node's call arguments, so need to know the function arguments' names in case the weights are in the kwargs. - Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so - it will only include the arguments that may contain constants. For example, we don't - want the transpose_a attribute of tf.matmul to be saved as a constant. - - Every operation we add support to, needs to be added here. - Args: tfoplambda_layer: TFOpLambda layer. Returns: A dictionary with argument number and index: {arg_name: arg_index}. """ - kwargs2index = {tf.add: {'x': 0, 'y': 1}, - tf.subtract: {'x': 0, 'y': 1}, - tf.divide: {'x': 0, 'y': 1}, - tf.truediv: {'x': 0, 'y': 1}, - tf.multiply: {'x': 0, 'y': 1}, - tf.pow: {'x': 0, 'y': 1}, - tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function) - if not kwargs2index: - # In TF 2.15 the function attribute is different and doesn't match the original - # operation object we use. Therefore, we extract kwargs2index with the symbol. 
- kwargs2index = {'__operators__.add': {'x': 0, 'y': 1}, - 'math.add': {'x': 0, 'y': 1}, - 'math.multiply': {'x': 0, 'y': 1}, - 'linalg.matmul': {'a': 0, 'b': 1}, - 'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {}) - - return kwargs2index + + full_args = tf_inspect.getfullargspec(tfoplambda_layer.function).args + + return {arg_name: i for i, arg_name in enumerate(full_args)} + + +def _extract_const_attrs_from_kwargs(op_call_kwargs: Dict[str, Any], + kwarg2index: Dict[str, int], + weights: Dict[Union[str, int], Any]) -> Dict[str, Any]: + """ + Extract const weights of the layer from the operator's key arguments dictionary. + This function extracts the attributes, updates the nodes weights dictionary and removes them from the original + kwargs mapping. + + Args: + op_call_kwargs: A mapping of the operator key arguments. + kwarg2index: A dictionary with argument number and index: {arg_name: arg_index}. + weights: Node weights mapping. This dictionary is modified by this function. + + Returns: A modified operator key arguments mapping. + + """ + + # read weights from call kwargs + for k, v in op_call_kwargs.items(): + if is_const(v): + # if k in kwarg2index: + weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)}) + + # remove weights and KerasTensors from op_call_kwargs + op_call_kwargs = {k: v for k, v in op_call_kwargs.items() + if not (kwarg2index.get(k) in weights or is_tensor(v))} + + return op_call_kwargs + + +def _build_arguments_alloc(n: KerasNode, inputs_as_list: bool, kwarg2index: Dict[str, int]) -> List: + """ + Builds arguments allocation list. + In Keras, if there is any argument that is a constant, we convert all arguments and inputs to be + considered as op kwargs for simpler reconstruction of the model from the graph later. + Therefore, we build a location list that includes the argument names (keys). 
+ If the input is a list, then we don't need to save the keys, since we can assume that all possible constant + arguments are within the first argument (the list) and are stored by their position in the list. + + Args: + n: Keras node. + inputs_as_list: Whether the node's inputs are a list. + + Returns: + A list of argument allocations in the node's inputs. + + """ + + tensor_input_alloc = [] + op_call_args = list(n.call_args) + if not inputs_as_list: + sorted_kwargs_pos = sorted(kwarg2index.items(), key=lambda x: x[1]) + tensor_input_alloc = [k for k, _ in sorted_kwargs_pos[:len(op_call_args)]] + for k, idx in sorted_kwargs_pos[len(op_call_args):]: + if k in n.call_kwargs: + tensor_input_alloc.append(k) + + return tensor_input_alloc + +def _extract_const_attrs_from_args(op_call_args: List[Any], + op_call_kwargs: Dict[str, Any], + inputs_as_list: bool, + tensor_inputs_alloc: List, + weights: Dict[Union[str, int], Any]) -> Tuple: + """ + Extract const weights of the layer from the operator's arguments list. + This function extracts the attributes, updates the nodes weights dictionary and removes them from the original + arguments list. + + Args: + op_call_args: A list of the operator arguments. + op_call_kwargs: A mapping of key-arguments of the operator. + inputs_as_list: Whether the input of the layer is a list. + tensor_inputs_alloc: Allocation of argument inputs to the operator (if there are const inputs, otherwise None). + weights: Node weights mapping. This dictionary is modified by this function. + + Returns: A modified operator arguments list.
+ + """ + + move_args_to_kwargs = tensor_inputs_alloc is not None and len(tensor_inputs_alloc) > 0 + + # read weights from call args + for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args): + if is_const(arg): + weights.update({i: to_numpy(arg, is_single_tensor=True)}) + else: + if not inputs_as_list: + if move_args_to_kwargs: + # In this case we move all arguments and inputs to the kwargs + op_call_kwargs.update({tensor_inputs_alloc[i]: arg}) + + # remove weights and KerasTensors from op_call_args + if inputs_as_list: + op_call_args = tuple(op_call_args[1:]) + else: + op_call_args = tuple([a for i, a in enumerate(op_call_args) + if not (i in weights or is_tensor(a) or (move_args_to_kwargs and tensor_inputs_alloc[i] + in op_call_kwargs))]) + + return op_call_args + + +def _has_const_attributes(op_call_args: List, op_call_kwargs: Dict, input_as_list: bool) -> bool: + """ + Returns whether the layer's input include a constant tensor (that we might want to quantize). + + Args: + op_call_args: A list of arguments to the layer. + op_call_kwargs: A dictionary of key-arguments to the layer. + input_as_list: Whether the input to the layer is a list of tensors. + + Returns: True if the input arguments include a constant tensor, False otherwise. + + """ + if input_as_list: + return any([is_const(a) for a in op_call_args[0]]) + const_args = [a for a in op_call_args if is_const(a)] + const_kwargs = [k for k, v in op_call_kwargs.items() if is_const(v)] + + return len(const_args) > 0 or len(const_kwargs) > 0 def build_node(node: KerasNode, @@ -110,8 +216,8 @@ def build_node(node: KerasNode, """ keras_layer = node.layer # get the layer the node represents. layer_config = keras_layer.get_config() # layer configuration to reconstruct it. 
- op_call_args = node.call_args - op_call_kwargs = node.call_kwargs + op_call_args = copy(node.call_args) + op_call_kwargs = copy(node.call_kwargs) layer_class = type(keras_layer) # class path to instantiating it in back2framework. weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights @@ -152,32 +258,14 @@ def build_node(node: KerasNode, if len(weights) > 0: Logger.critical('Functional nodes are not expected to have weights in this framework.') - # read weights from call args - tf_function_symbols = get_tf_function_symbols() - for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args): - if is_const(arg) or ( - keras_layer.symbol in tf_function_symbols and - isinstance(arg, (tuple, list))): - if inputs_as_list or i in kwarg2index.values(): - weights.update({i: to_numpy(arg, is_single_tensor=True)}) - # remove weights and KerasTensors and weights from op_call_args - if inputs_as_list: - op_call_args = tuple(op_call_args[1:]) - else: - op_call_args = tuple([a for i, a in enumerate(op_call_args) - if not (i in weights or is_tensor(a))]) - - # read weights from call kwargs - weight_keys = [] - for k, v in op_call_kwargs.items(): - if is_const(v) or (keras_layer.symbol in tf_function_symbols and - isinstance(v, (tuple, list))): - if k in kwarg2index: - weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)}) - weight_keys.append(k) - # remove weights and KerasTensors and weights from op_call_kwargs - op_call_kwargs = {k: v for k, v in op_call_kwargs.items() - if not (kwarg2index.get(k) in weights or is_tensor(v))} + # Build tensor_input_alloc required for the model builder. All inputs are received as a list in the builder, + # so tensor_input_alloc is used to allocate each input in the correct place in the node's args & kwargs. 
+ tensor_input_alloc = None if not _has_const_attributes(op_call_args, op_call_kwargs, inputs_as_list) \ + else _build_arguments_alloc(node, inputs_as_list, kwarg2index) + + op_call_args = _extract_const_attrs_from_args(op_call_args, op_call_kwargs, inputs_as_list, + tensor_input_alloc, weights) + op_call_kwargs = _extract_const_attrs_from_kwargs(op_call_kwargs, kwarg2index, weights) node = FunctionalNode(node_name, layer_config, @@ -190,7 +278,8 @@ def build_node(node: KerasNode, is_reused, reuse_group, functional_op=keras_layer.function, - inputs_as_list=inputs_as_list) + inputs_as_list=inputs_as_list, + tensor_input_allocs=tensor_input_alloc) else: # Read constant weights from layers such as layers.Add if len(op_call_args) > 0 and isinstance(op_call_args[0], (list, tuple)): diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py index 05216b752..e08a34b3b 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py @@ -17,9 +17,10 @@ import numpy as np import model_compression_toolkit as mct +from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest from tests.common_tests.helpers.tensors_compare import cosine_similarity -from mct_quantizers import KerasQuantizationWrapper +from mct_quantizers import KerasQuantizationWrapper, QuantizationMethod from model_compression_toolkit.constants import TENSORFLOW from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py 
b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py index afd83fc57..29daca9b3 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py @@ -16,6 +16,9 @@ import numpy as np import model_compression_toolkit as mct +from model_compression_toolkit import get_target_platform_capabilities +from model_compression_toolkit.constants import TENSORFLOW +from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest @@ -69,6 +72,31 @@ def create_networks(self): x = self.layer(x, self.const) return tf.keras.models.Model(inputs=inputs, outputs=x) + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): + y = float_model.predict(input_x) + y_hat = quantized_model.predict(input_x) + self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!') + cs = cosine_similarity(y, y_hat) + self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}, for operator {self.layer}') + + +class ConstRepresentationListTypeArgsTest(BaseKerasFeatureNetworkTest): + + def __init__(self, unit_test, input_shape=(32, 32, 16)): + super(ConstRepresentationListTypeArgsTest, self).__init__(unit_test=unit_test, input_shape=input_shape) + + def generate_inputs(self): + # need positive inputs so won't divide with zero or take root of negative number + return [1 + np.random.random(in_shape) for in_shape in self.get_input_shapes()] + + def get_tpc(self): + return get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL) + + def 
create_networks(self): + inputs = layers.Input(shape=self.get_input_shapes()[0][1:]) + x = tf.image.resize(inputs, size=self.get_input_shapes()[0][1:3]) + return tf.keras.models.Model(inputs=inputs, outputs=x) + def compare(self, quantized_model, float_model, input_x=None, quantization_info=None): y = float_model.predict(input_x) y_hat = quantized_model.predict(input_x) diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 8f8b00916..e895cb1e5 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -137,7 +137,7 @@ from tests.keras_tests.feature_networks_tests.feature_networks.metadata_test import MetadataTest from tests.keras_tests.feature_networks_tests.feature_networks.tpc_test import TpcTest from tests.keras_tests.feature_networks_tests.feature_networks.const_representation_test import ConstRepresentationTest, \ - ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest + ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest, ConstRepresentationListTypeArgsTest from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest from tests.keras_tests.feature_networks_tests.feature_networks.const_quantization_test import ConstQuantizationTest, \ AdvancedConstQuantizationTest @@ -560,7 +560,6 @@ def test_const_quantization(self): ConstQuantizationTest(self, func, c, input_reverse_order=True, qmethod=qmethod).run_test() ConstQuantizationTest(self, func, c, input_reverse_order=True, use_kwargs=True, qmethod=qmethod).run_test() ConstQuantizationTest(self, func, c, use_kwargs=True, qmethod=qmethod).run_test() - ConstQuantizationTest(self, func, 2.45, qmethod=qmethod).run_test() ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod).run_test() 
AdvancedConstQuantizationTest(self).run_test() @@ -586,6 +585,7 @@ def test_const_representation(self): ConstRepresentationTest(self, func, c, use_kwargs=True, is_list_input=True).run_test() ConstRepresentationMultiInputTest(self).run_test() + ConstRepresentationListTypeArgsTest(self).run_test() def test_second_moment(self): DepthwiseConv2DSecondMomentTest(self).run_test()