diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py
index 9833e20cd..b90bc6a87 100644
--- a/model_compression_toolkit/core/common/graph/base_node.py
+++ b/model_compression_toolkit/core/common/graph/base_node.py
@@ -40,6 +40,7 @@ def __init__(self,
                  layer_class: type,
                  reuse: bool = False,
                  reuse_group: str = None,
+                 inputs_as_list: bool = False,
                  quantization_attr: Dict[str, Any] = None,
                  has_activation: bool = True,
                  is_custom: bool = False
@@ -58,6 +59,7 @@ def __init__(self,
             layer_class: Class path of the layer this node represents.
             reuse: Whether this node was duplicated and represents a reused layer.
             reuse_group: Name of group of nodes from the same reused layer.
+            inputs_as_list: Whether to pass the node's input tensors as a single list when calling the layer.
             quantization_attr: Attributes the node holds regarding how it should be quantized.
             has_activation: Whether the node has activations that we might want to quantize.
             is_custom: Whether the node is custom layer or not.
@@ -71,6 +73,7 @@ def __init__(self,
         self.layer_class = layer_class
         self.reuse = reuse
         self.reuse_group = reuse_group
+        self.inputs_as_list = inputs_as_list
         self.final_weights_quantization_cfg = None
         self.final_activation_quantization_cfg = None
         self.candidates_quantization_cfg = None
diff --git a/model_compression_toolkit/core/common/graph/functional_node.py b/model_compression_toolkit/core/common/graph/functional_node.py
index 8743875da..bcf2cc15e 100644
--- a/model_compression_toolkit/core/common/graph/functional_node.py
+++ b/model_compression_toolkit/core/common/graph/functional_node.py
@@ -55,13 +55,13 @@ def __init__(self,
                          layer_class,
                          reuse,
                          reuse_group,
+                         inputs_as_list,
                          quantization_attr,
                          has_activation=has_activation)

         self.op_call_kwargs = op_call_kwargs
         self.op_call_args = list(op_call_args)
         self.functional_op = functional_op
-        self.inputs_as_list = inputs_as_list
         self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

     @property
diff --git a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py
index 4620309fb..920b720f9 100644
--- a/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py
+++ b/model_compression_toolkit/core/keras/back2framework/keras_model_builder.py
@@ -308,7 +308,7 @@ def _run_operation(self,
             else:
                 # If operator expects a single input tensor, it cannot be a list as it should
                 # have a dtype field.
-                if len(input_tensors) == 1:
+                if len(input_tensors) == 1 and not n.inputs_as_list:
                     input_tensors = input_tensors[0]
                 out_tensors_of_n_float = op_func(input_tensors)
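Note on the calling convention `inputs_as_list` encodes: Keras `_Merge` layers (`Add`, `Multiply`, `Concatenate`, ...) are called with a single Python list of tensors, while functional `TFOpLambda` ops take separate tensor arguments — which is why `_run_operation` must not unwrap a one-element input list for merge nodes. A minimal standalone sketch (illustrative only, not part of the patch):

```python
import tensorflow as tf
from tensorflow.keras import layers

a = tf.keras.Input(shape=(8,))
b = tf.keras.Input(shape=(8,))

# _Merge subclasses take one list argument, so the builder must keep
# the input list intact even if only one dynamic tensor remains after
# constants were folded into the node's weights.
out_merge = layers.Add()([a, b])

# A functional op traced as TFOpLambda takes separate tensor arguments instead.
out_lambda = tf.add(a, b)
```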
diff --git a/model_compression_toolkit/core/keras/reader/node_builder.py b/model_compression_toolkit/core/keras/reader/node_builder.py
index 84b9686d5..1b5fa0b7f 100644
--- a/model_compression_toolkit/core/keras/reader/node_builder.py
+++ b/model_compression_toolkit/core/keras/reader/node_builder.py
@@ -30,10 +30,12 @@
     from keras.src.layers.core import TFOpLambda, SlicingOpLambda
     from keras.src.engine.keras_tensor import KerasTensor
     from keras.src.engine.node import Node as KerasNode
+    from keras.src.layers.merging.base_merge import _Merge
 else:
     from keras.layers.core import TFOpLambda, SlicingOpLambda
     from keras.engine.keras_tensor import KerasTensor
     from keras.engine.node import Node as KerasNode
+    from keras.layers.merging.base_merge import _Merge

 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -287,6 +289,7 @@ def build_node(node: KerasNode,
             for i, arg in enumerate(op_call_args[0]):
                 if is_const(arg):
                     weights.update({i: to_numpy(arg, is_single_tensor=True)})
+        inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)

         node = BaseNode(node_name,
                         layer_config,
@@ -296,6 +299,7 @@ def build_node(node: KerasNode,
                         layer_class,
                         is_reused,
                         reuse_group,
+                        inputs_as_list,
                         is_custom=is_keras_custom_layer(layer_class))

     node_name_to_node[node_name] = node
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
     """
     return (keras_layer.symbol in
-            [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
+            [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
             len(op_call_args) > 0 and
             isinstance(op_call_args[0], list))
+
+
+def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
+    """
+    Check whether the input tensors should be passed as a list or not. This is relevant
+    only for layers that inherit from _Merge, such as Concatenate and Add.
+
+    Args:
+        op_call_args: Arguments list to check.
+        keras_layer: Keras layer.
+
+    Returns:
+        Whether the input tensors should be passed as a list or not.
+    """
+
+    return (isinstance(keras_layer, _Merge) and
+            len(op_call_args) > 0 and
+            isinstance(op_call_args[0], (list, tuple)))
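Note: unlike `__is_functional_inputs_a_list`, which matches specific `TFOpLambda` symbols, `__is_node_inputs_a_list` keys off the `_Merge` base class, so it covers every merge layer at once. An illustrative check (the `_Merge` import path varies with the Keras version, mirroring the version branch above):

```python
# keras.src.layers.merging.base_merge on newer Keras versions.
from keras.layers.merging.base_merge import _Merge
from tensorflow.keras import layers

for layer in (layers.Concatenate(), layers.Add(), layers.Dense(4)):
    # Only the first two are _Merge subclasses and expect a list input.
    print(type(layer).__name__, isinstance(layer, _Merge))
```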
diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
index af3ed0a6f..3d7e1f32f 100644
--- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
+++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
         _tensor_input_allocs = None

     if isinstance(n, FunctionalNode) and n.inputs_as_list:
-        out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
+        if isinstance(op_func, PytorchQuantizationWrapper):
+            # For wrapped nodes, the op's args & kwargs are already stored in the PytorchQuantizationWrapper.
+            out_tensors_of_n_float = op_func(*input_tensors)
+        else:
+            out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
     else:
         merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
                                                          tensor_input_allocs=_tensor_input_allocs)
diff --git a/model_compression_toolkit/core/pytorch/reader/graph_builders.py b/model_compression_toolkit/core/pytorch/reader/graph_builders.py
index f8dcab6a3..47ea21462 100644
--- a/model_compression_toolkit/core/pytorch/reader/graph_builders.py
+++ b/model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -232,10 +232,19 @@ def nodes_builder(model: GraphModule,
         # Add constants to weights dictionary.
         if node.op != PLACEHOLDER:
-            for i, input_node in enumerate(node.all_input_nodes):
-                if input_node in consts_dict:
-                    used_consts.add(input_node)
-                    weights.update({i: consts_dict[input_node]})
+            if len(node.args) and isinstance(node.args[0], (list, tuple)):
+                # Handle weights in nodes with a list input, in particular when the same tensor
+                # appears more than once in the list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
+                for input_node in node.all_input_nodes:
+                    for i, input_arg in enumerate(node.args[0]):
+                        if input_node is input_arg and input_node in consts_dict:
+                            used_consts.add(input_node)
+                            weights.update({i: consts_dict[input_node]})
+            else:
+                for i, input_node in enumerate(node.all_input_nodes):
+                    if input_node in consts_dict:
+                        used_consts.add(input_node)
+                        weights.update({i: consts_dict[input_node]})

         # Extract input and output shapes of the node.
         input_shape, output_shape = _extract_input_and_output_shapes(node)
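Note on the duplicate-input handling above: `torch.fx` lists each producing node only once in `node.all_input_nodes`, while `node.args[0]` preserves the call's ordering, duplicates included — so positional weight indices must be taken from `args[0]`. A small standalone illustration (assumed shapes, not part of the patch):

```python
import torch
from torch import fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('c1', torch.randn(1, 3, 8, 8))

    def forward(self, x):
        # the same tensors appear more than once in the input list
        return torch.cat([self.c1, x, self.c1, x], 1)

gm = fx.symbolic_trace(M())
cat_node = next(n for n in gm.graph.nodes
                if n.op == 'call_function' and n.target in (torch.cat, torch.concat))
print(cat_node.args[0])          # full input list, duplicates preserved
print(cat_node.all_input_nodes)  # each producing node listed once
```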
diff --git a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
index 706bf4d97..47e42d1c7 100644
--- a/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
+++ b/model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-from typing import Tuple, Callable
+from typing import Tuple, Callable, Union
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,10 +25,12 @@
     import tensorflow as tf
     from tensorflow.keras.layers import Layer
     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+    from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
     from mct_quantizers import KerasQuantizationWrapper
     from mct_quantizers import KerasActivationQuantizationHolder
+    from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS

-    def _get_wrapper(node: common.BaseNode,
+    def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
                      layer: Layer,
                      fw_impl=None) -> Layer:
         """
@@ -45,9 +47,16 @@ def _get_wrapper(node: common.BaseNode,
             # for positional weights we need to extract the weight's value.
             weights_values = {attr: node.get_weights_by_keys(attr)
                               for attr in weights_quantizers if isinstance(attr, int)}
+            # When wrapping functional nodes, we need to pass the call args/kwargs to the wrapper,
+            # since they are used in the wrapper's call method.
+            func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
+                                OP_CALL_KWARGS: node.op_call_kwargs
+                                } if isinstance(node, FunctionalNode) else {}
             return KerasQuantizationWrapper(layer,
                                             weights_quantizers,
-                                            weights_values)
+                                            weights_values,
+                                            is_inputs_as_list=node.inputs_as_list,
+                                            **func_node_kwargs)
         return layer
diff --git a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
index ea9ba14a6..64621f265 100644
--- a/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
+++ b/model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
@@ -24,7 +24,9 @@
 if FOUND_TORCH:
     import torch
     from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
+    from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
     from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+    from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode

     def fully_quantized_wrapper(node: common.BaseNode,
@@ -46,7 +48,14 @@ def fully_quantized_wrapper(node: common.BaseNode,
             # for positional weights we need to extract the weight's value.
             weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
                               for attr in weight_quantizers if isinstance(attr, int)}
-            return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
+            # When wrapping functional nodes, we need to pass the call args/kwargs to the wrapper,
+            # since they are used in the wrapper's call method.
+            func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
+                                OP_CALL_KWARGS: node.op_call_kwargs
+                                } if isinstance(node, FunctionalNode) else {}
+            return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
+                                              is_inputs_as_list=node.inputs_as_list,
+                                              **func_node_kwargs)
         return module
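Note: both exporters now rely on the wrapper consuming `is_inputs_as_list` and the stored call args/kwargs at call time. The expected behavior is roughly sketched below — this is not mct_quantizers' actual code, and all names here are illustrative:

```python
def call_wrapped_op(op, quantized_positional_weights, dynamic_inputs,
                    op_call_args, op_call_kwargs, is_inputs_as_list):
    # Re-insert the (quantized) constants at their original positional indices
    # among the dynamic inputs, then invoke the op the way the original graph did.
    inputs = list(dynamic_inputs)
    for pos, const in sorted(quantized_positional_weights.items()):
        inputs.insert(pos, const)
    if is_inputs_as_list:
        # e.g. torch.cat(tensor_list, dim) or a Keras Concatenate()(tensor_list)
        return op(inputs, *op_call_args, **op_call_kwargs)
    return op(*inputs, *op_call_args, **op_call_kwargs)
```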
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
index 765d4e979..c0ba52620 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
@@ -18,13 +18,14 @@
 import model_compression_toolkit as mct
 from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tp_model import generate_tp_model, \
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import generate_tp_model, \
     get_op_quantization_configs
-from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v3.tpc_keras import generate_keras_tpc
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tpc_keras import generate_keras_tpc
 from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG, \
     generate_test_tp_model, generate_custom_test_tp_model
 from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
 from tests.common_tests.helpers.tensors_compare import cosine_similarity
+from tests.keras_tests.utils import get_layers_from_model_by_type
 from mct_quantizers import KerasQuantizationWrapper, QuantizationMethod
 from model_compression_toolkit.constants import TENSORFLOW
@@ -35,6 +36,39 @@
 tp = mct.target_platform


+def create_const_quant_tpc(qmethod):
+    name = "const_quant_tpc"
+    base_cfg, mp_op_cfg_list, default_cfg = get_op_quantization_configs()
+    base_tp_model = generate_tp_model(default_config=default_cfg,
+                                      base_config=base_cfg,
+                                      mixed_precision_cfg_list=mp_op_cfg_list,
+                                      name=name)
+
+    const_config = default_cfg.clone_and_edit(
+        default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
+            enable_weights_quantization=True, weights_per_channel_threshold=True,
+            weights_n_bits=16, weights_quantization_method=qmethod))
+    const_configuration_options = tp.QuantizationConfigOptions([const_config])
+    const_merge_config = default_cfg.clone_and_edit(
+        default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
+            weights_per_channel_threshold=False))
+    const_merge_configuration_options = tp.QuantizationConfigOptions([const_merge_config])
+
+    operator_sets_dict = {}
+    operator_sets_dict["Add"] = const_configuration_options
+    operator_sets_dict["Sub"] = const_configuration_options
+    operator_sets_dict["Mul"] = const_configuration_options
+    operator_sets_dict["Div"] = const_configuration_options
+    operator_sets_dict["MergeOps"] = const_merge_configuration_options
+
+    tp_model = generate_custom_test_tp_model(name=name,
+                                             base_cfg=base_cfg,
+                                             base_tp_model=base_tp_model,
+                                             operator_sets_dict=operator_sets_dict)
+
+    return generate_keras_tpc(name="const_quant_tpc", tp_model=tp_model)
+
+
 class ConstQuantizationTest(BaseKerasFeatureNetworkTest):

     def __init__(self, unit_test, layer, const, is_list_input=False, input_reverse_order=False, use_kwargs=False,
@@ -58,31 +92,7 @@ def get_quantization_config(self):
         return mct.core.QuantizationConfig(weights_error_method=self.error_method)

     def get_tpc(self):
-        name = "const_quant_tpc"
-        base_cfg, mp_op_cfg_list, default_cfg = get_op_quantization_configs()
-        base_tp_model = generate_tp_model(default_config=default_cfg,
-                                          base_config=base_cfg,
-                                          mixed_precision_cfg_list=mp_op_cfg_list,
-                                          name=name)
-
-        const_config = default_cfg.clone_and_edit(
-            default_weight_attr_config=default_cfg.default_weight_attr_config.clone_and_edit(
-                enable_weights_quantization=True, weights_per_channel_threshold=True,
-                weights_quantization_method=self.qmethod))
-        const_configuration_options = tp.QuantizationConfigOptions([const_config])
-
-        operator_sets_dict = {}
-        operator_sets_dict["Add"] = const_configuration_options
-        operator_sets_dict["Sub"] = const_configuration_options
-        operator_sets_dict["Mul"] = const_configuration_options
-        operator_sets_dict["Div"] = const_configuration_options
-
-        tp_model = generate_custom_test_tp_model(name=name,
-                                                 base_cfg=base_cfg,
-                                                 base_tp_model=base_tp_model,
-                                                 operator_sets_dict=operator_sets_dict)
-
-        return generate_keras_tpc(name="const_quant_tpc", tp_model=tp_model)
+        return create_const_quant_tpc(self.qmethod)

     def create_networks(self):
         inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
@@ -159,3 +169,37 @@
                                   msg='TFOpLambda should be quantized')
         self.unit_test.assertTrue((quantized_model.layers[5].weight_values[1] == self.const).all(),
                                   msg='Constant value should not change')
+
+
+class ConstQuantizationMultiInputTest(BaseKerasFeatureNetworkTest):
+
+    def __init__(self, unit_test, input_shape=(32, 32, 16)):
+        super(ConstQuantizationMultiInputTest, self).__init__(unit_test=unit_test, input_shape=input_shape)
+
+    def get_tpc(self):
+        return mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v4")
+
+    def create_networks(self):
+        as_const = lambda v: np.random.random(v.shape.as_list()).astype(np.float32)
+        inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
+        x = layers.Concatenate()([inputs, np.random.random((1, 32, 32, 3)),
+                                  inputs, np.random.random((1, 32, 32, 3))])
+        x1 = layers.Add()([np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
+        x2 = layers.Multiply()([x, np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
+        x3 = tf.add_n([x1, as_const(x), x2])
+        x1 = tf.reshape(tf.stack([as_const(x1), x1, as_const(x1)], axis=1), (-1, 3*x1.shape[1], x1.shape[2], x1.shape[3]))
+        x = tf.concat([x1, x2, as_const(x3), x3], 1)
+        return tf.keras.models.Model(inputs=inputs, outputs=x)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        y = float_model.predict(input_x)
+        y_hat = quantized_model.predict(input_x)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg='out shape is not as expected!')
+        cs = cosine_similarity(y, y_hat)
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-2), msg=f'fail cosine similarity check: {cs}')
+
+        # check quantization layers:
+        for op in [tf.concat, tf.stack, layers.Add, layers.Multiply, layers.Concatenate]:
+            for qlayer in get_layers_from_model_by_type(quantized_model, op):
+                self.unit_test.assertTrue(isinstance(qlayer, KerasQuantizationWrapper),
+                                          msg=f"{op} should be quantized.")
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
index 29daca9b3..bbefb4bcd 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
@@ -153,12 +153,14 @@ def get_tpc(self):
         return generate_keras_tpc(name="const_representation_test", tp_model=tp)

     def create_networks(self):
+        as_const = lambda v: np.random.random(v.shape.as_list()).astype(np.float32)
         inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
         x = layers.Concatenate()([inputs, np.random.random((1, 32, 32, 3)),
                                   inputs, np.random.random((1, 32, 32, 3))])
         x1 = layers.Add()([np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
         x2 = layers.Multiply()([x, np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
-        x3 = tf.add_n([x1, np.random.random(x.shape.as_list()).astype(np.float32), x2])
-        x = tf.concat([x1, x2, np.random.random(x3.shape.as_list()).astype(np.float32), x3], 1)
+        x3 = tf.add_n([x1, as_const(x), x2])
+        x1 = tf.reshape(tf.stack([as_const(x1), x1, as_const(x1)], axis=1), (-1, 3*x1.shape[1], x1.shape[2], x1.shape[3]))
+        x = tf.concat([x1, x2, as_const(x3), x3], 1)
         return tf.keras.models.Model(inputs=inputs, outputs=x)

     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
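Note: the stack-and-reshape step added to both Keras test networks interleaves constants with activations along the height axis, exercising `tf.stack` with const list inputs. A quick standalone shape check (assumed shapes, not test code):

```python
import tensorflow as tf

x1 = tf.zeros((1, 32, 32, 16))
c = tf.zeros_like(x1)
s = tf.stack([c, x1, c], axis=1)         # -> (1, 3, 32, 32, 16)
r = tf.reshape(s, (-1, 3 * 32, 32, 16))  # -> (1, 96, 32, 16)
print(s.shape, r.shape)
```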
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index ec1fffdf3..45c8cef3e 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -147,7 +147,7 @@
     ConstRepresentationMultiInputTest, ConstRepresentationMatMulTest, ConstRepresentationListTypeArgsTest
 from tests.keras_tests.feature_networks_tests.feature_networks.concatination_threshold_update import ConcatThresholdtest
 from tests.keras_tests.feature_networks_tests.feature_networks.const_quantization_test import ConstQuantizationTest, \
-    AdvancedConstQuantizationTest
+    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest
 from tests.keras_tests.feature_networks_tests.feature_networks.activation_16bit_test import Activation16BitTest, \
     Activation16BitMixedPrecisionTest
 from tests.keras_tests.feature_networks_tests.feature_networks.sigmoid_mul_substitution_test import SigMulSubstitutionTest
@@ -588,6 +588,7 @@ def test_const_quantization(self):
                 ConstQuantizationTest(self, func, 5.1, input_reverse_order=True, qmethod=qmethod,
                                       error_method=error_method).run_test()
         AdvancedConstQuantizationTest(self).run_test()
+        ConstQuantizationMultiInputTest(self).run_test()

     def test_const_representation(self):
         c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32)
diff --git a/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
index 7d8155136..ad865ad73 100644
--- a/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/const_quantization_test.py
@@ -21,6 +21,7 @@
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
 from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
 from tests.common_tests.helpers.tensors_compare import cosine_similarity
+from tests.pytorch_tests.utils import get_layers_from_model_by_type
 from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
 from model_compression_toolkit.constants import PYTORCH
 from mct_quantizers import PytorchQuantizationWrapper
@@ -138,3 +139,56 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         self.unit_test.assertTrue((list(m.weight_values.values())[0].detach().cpu().numpy() == self.const).all(),
                                   msg=f'Expected PytorchQuantizationWrapper const value to match float const.')
+
+
+class MultiInputConstQuantizationNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer('cat_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 16, 1, 32))))
+        self.register_buffer('cat_const_2', to_torch_tensor(np.random.randint(-128, 127, size=(1, 16, 3, 32))))
+        self.register_buffer('concat_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 16, 36, 4))))
+        self.register_buffer('concatenate_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 1, 36, 36))))
+        self.register_buffer('concatenate_const_2', to_torch_tensor(np.random.randint(-128, 127, size=(1, 2, 36, 36))))
+        self.register_buffer('concatenate_const_3', to_torch_tensor(np.random.randint(-128, 127, size=(1, 3, 36, 36))))
+        self.register_buffer('stack_const_1', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
+        self.register_buffer('stack_const_2', to_torch_tensor(np.random.randint(-128, 127, size=(1, 39, 36, 36))))
+
+    def forward(self, x):
+        x = torch.cat([self.cat_const_1, x, self.cat_const_2], dim=2)
+        x = torch.concat([self.concat_const_1, x], dim=3)
+        x = torch.concatenate([self.concatenate_const_1, x,
+                               self.concatenate_const_2, x,
+                               self.concatenate_const_3, self.concatenate_const_1], dim=1)
+        x = torch.stack([self.stack_const_1, x, self.stack_const_2], dim=1)
+        x = torch.reshape(x, (1, 3*39, 36, 36))
+        return x
+
+
+class ConstQuantizationMultiInputTest(BasePytorchFeatureNetworkTest):
+
+    def __init__(self, unit_test):
+        super().__init__(unit_test=unit_test, input_shape=(16, 32, 32))
+
+    def generate_inputs(self):
+        return [np.random.randint(-128, 127, size=in_shape) for in_shape in self.get_input_shapes()]
+
+    def get_tpc(self):
+        return mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v4")
+
+    def create_networks(self):
+        return MultiInputConstQuantizationNet()
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        in_torch_tensor = to_torch_tensor(input_x[0])
+        set_model(float_model)
+        y = float_model(in_torch_tensor)
+        y_hat = quantized_model(in_torch_tensor)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg='out shape is not as expected!')
+        cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
+        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')
+
+        # check quantization layers:
+        for op in [torch.cat, torch.concat, torch.concatenate, torch.stack]:
+            for qlayer in get_layers_from_model_by_type(quantized_model, op):
+                self.unit_test.assertTrue(isinstance(qlayer, PytorchQuantizationWrapper),
+                                          msg=f"{op} should be quantized.")
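For reference, a shape walk-through of `MultiInputConstQuantizationNet.forward`, derived from the buffer sizes above and the test's `(16, 32, 32)` input with batch 1:

```python
# x:                 (1, 16, 32, 32)
# cat    dim=2:      (1, 16, 1+32+3, 32)        -> (1, 16, 36, 32)
# concat dim=3:      (1, 16, 36, 4+32)          -> (1, 16, 36, 36)
# concatenate dim=1: (1, 1+16+2+16+3+1, 36, 36) -> (1, 39, 36, 36)
# stack  dim=1:      (1, 3, 39, 36, 36)
# reshape:           (1, 3*39, 36, 36)          -> (1, 117, 36, 36)
```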
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index d489825eb..9d299be68 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -103,7 +103,7 @@
     ConstRepresentationCodeTest
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \
-    AdvancedConstQuantizationTest
+    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest
 from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
 from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
     Activation16BitMixedPrecisionTest
@@ -263,6 +263,7 @@ def test_const_quantization(self):
             ConstQuantizationTest(self, func, 5, input_reverse_order=True).run_test()

         AdvancedConstQuantizationTest(self).run_test()
+        ConstQuantizationMultiInputTest(self).run_test()

     def test_const_representation(self):
         for const_dtype in [np.float32, np.int64, np.int32]: