Skip to content

Commit

Permalink
Keras node builder and model builder with positional weights refactor (
Browse files Browse the repository at this point in the history
…#1127)

Refactor to the Keras Node builder and cost representation.
Main modifications:

1. Remove hard-coded operator mapping and replace with tf.inspect to extract const weights from TFOpLambda layers.
2. Change back2framework when dealing with positional weights - all inputs are retrieved as kwargs when building the model from the graph.
3. Minor refactor to improve Keras node builder readability.

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon authored Jul 18, 2024
1 parent 71c2549 commit 9d54fe9
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 62 deletions.
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self,
framework_attr: Dict[str, Any],
input_shape: Tuple[Any],
output_shape: Tuple[Any],
weights: Dict[str, np.ndarray],
weights: Dict[Union[str, int], np.ndarray],
layer_class: type,
reuse: bool = False,
reuse_group: str = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self,
has_activation=has_activation)

self.op_call_kwargs = op_call_kwargs
self.op_call_args = op_call_args
self.op_call_args = list(op_call_args)
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from copy import copy

import tensorflow as tf
from keras.models import Model
from packaging import version

from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.logger import Logger

if version.parse(tf.__version__) >= version.parse("2.13"):
from keras import Input
Expand Down Expand Up @@ -271,15 +273,38 @@ def _run_operation(self,
out_tensors_of_n_float)
else:
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
if isinstance(n, FunctionalNode):
op_call_kwargs = {} if n.op_call_kwargs is None else copy(n.op_call_kwargs)
if not isinstance(op_func, KerasQuantizationWrapper):
# The KerasQuantizationWrapper will insert the quantized positional weights internally.
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
if isinstance(n, FunctionalNode):
if n.tensor_input_allocs is not None:
if n.inputs_as_list:
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
else:
# If the were any const attributes in the layer's inputs, we retrieve them as kwargs
# for the operator call.
for pos, k in enumerate(n.tensor_input_allocs):
if k not in op_call_kwargs: # op_call_kwargs is initialized because we are under FunctionalNode
# If the argument is saved in tensor_input_allocs but does not exists in the node kwargs
# then it is expected to be either an input tensor or a positional weight of the node.
arg = n.weights.get(pos)
if arg is None:
if len(input_tensors) == 0:
Logger.critical(f"Couldn't find a weight or input tensor matching operator's "
f"argument name '{k}' in location {pos} for node {n.name}.")
arg = input_tensors.pop(0)
op_call_kwargs.update({k: arg})
else:
# If the operator is not a functional node then positional weights should be inserted
# into the inputs list.
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
# Build a functional node using its args
if isinstance(n, FunctionalNode):
if n.inputs_as_list: # If the first argument should be a list of tensors:
out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **n.op_call_kwargs)
out_tensors_of_n_float = op_func(input_tensors, *n.op_call_args, **op_call_kwargs)
else: # If the input tensors should not be a list but iterated:
out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs)
out_tensors_of_n_float = op_func(*input_tensors, *n.op_call_args, **op_call_kwargs)
else:
# If operator expects a single input tensor, it cannot be a list as it should
# have a dtype field.
Expand Down
197 changes: 143 additions & 54 deletions model_compression_toolkit/core/keras/reader/node_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, List, Dict
from copy import copy

from typing import Any, List, Dict, Union, Tuple

import tensorflow as tf
from tensorflow.python.util import tf_inspect
Expand Down Expand Up @@ -41,7 +43,7 @@

REUSED_IDENTIFIER = '_reused_'

is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, float))
is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray, tuple, list))
is_tensor = lambda x: isinstance(x, KerasTensor)


Expand All @@ -62,35 +64,139 @@ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
Positional weights are saved according to their index in the node's call arguments, so
need to know the function arguments' names in case the weights are in the kwargs.
Note: the kwargs2index dictionary is initialized manually (and not with tf_inspect) so
it will only include the arguments that may contain constants. For example, we don't
want the transpose_a attribute of tf.matmul to be saved as a constant.
Every operation we add support to, needs to be added here.
Args:
tfoplambda_layer: TFOpLambda layer.
Returns:
A dictionary with argument number and index: {arg_name: arg_index}.
"""
kwargs2index = {tf.add: {'x': 0, 'y': 1},
tf.subtract: {'x': 0, 'y': 1},
tf.divide: {'x': 0, 'y': 1},
tf.truediv: {'x': 0, 'y': 1},
tf.multiply: {'x': 0, 'y': 1},
tf.pow: {'x': 0, 'y': 1},
tf.matmul: {'a': 0, 'b': 1}}.get(tfoplambda_layer.function)
if not kwargs2index:
# In TF 2.15 the function attribute is different and doesn't match the original
# operation object we use. Therefore, we extract kwargs2index with the symbol.
kwargs2index = {'__operators__.add': {'x': 0, 'y': 1},
'math.add': {'x': 0, 'y': 1},
'math.multiply': {'x': 0, 'y': 1},
'linalg.matmul': {'a': 0, 'b': 1},
'concat': {'values': 0}}.get(tfoplambda_layer.symbol, {})

return kwargs2index

full_args = tf_inspect.getfullargspec(tfoplambda_layer.function).args

return {arg_name: i for i, arg_name in enumerate(full_args)}


def _extract_const_attrs_from_kwargs(op_call_kwargs: Dict[str, Any],
kwarg2index: Dict[str, int],
weights: Dict[Union[str, int], Any]) -> Dict[str, Any]:
"""
Extract const weights of the layer from the operator's key arguments dictionary.
This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
kwargs mapping.
Args:
op_call_kwargs: A mapping of the operator key arguments.
kwarg2index: A dictionary with argument number and index: {arg_name: arg_index}.
weights: Node weights mapping. This dictionary is modified by this function.
Returns: A modified operator key arguments mapping.
"""

# read weights from call kwargs
for k, v in op_call_kwargs.items():
if is_const(v):
# if k in kwarg2index:
weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})

# remove weights and KerasTensors from op_call_kwargs
op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
if not (kwarg2index.get(k) in weights or is_tensor(v))}

return op_call_kwargs


def _build_arguments_alloc(n: KerasNode, inputs_as_list: bool, kwarg2index: Dict[str, int]) -> List:
"""
Builds arguments allocation list.
In Keras, if there is any argument that is a constant, we convert all arguments and inputs to be
considered as op kwargs for simpler reconstruction of the model from the graph later.
Therefore, we build a location list that includes the argument names (keys).
If the input is a list, then we don't need to save the keys, since we can assume that all possible constant
arguments are within the first argument (the list) and are stored by their position in the list.
Args:
n: fx node.
inputs_as_list: Is node's inputs are a list.
Returns:
A list of argument allocations in the node's inputs.
"""

tensor_input_alloc = []
op_call_args = list(n.call_args)
if not inputs_as_list:
sorted_kwargs_pos = sorted(kwarg2index.items(), key=lambda x: x[1])
tensor_input_alloc = [k for k, _ in sorted_kwargs_pos[:len(op_call_args)]]
for k, idx in sorted_kwargs_pos[len(op_call_args):]:
if k in n.call_kwargs:
tensor_input_alloc.append(k)

return tensor_input_alloc

def _extract_const_attrs_from_args(op_call_args: List[Any],
op_call_kwargs: Dict[str, Any],
inputs_as_list: bool,
tensor_inputs_alloc: List,
weights: Dict[Union[str, int], Any]) -> Tuple:
"""
Extract const weights of the layer from the operator's arguments list.
This function extracts the attributes, updates the nodes weights dictionary and removes them from the original
arguments list.
Args:
op_call_args: A list of the operator arguments.
op_call_kwargs: A mapping of key-arguments of the operator.
inputs_as_list: Whether the input of the layer is a list.
tensor_inputs_alloc: Allocation of argument inputs to the operator (if there are const inputs, otherwise None).
weights: Node weights mapping. This dictionary is modified by this function.
Returns: A modified operator arguments list.
"""

move_args_to_kwargs = tensor_inputs_alloc is not None and len(tensor_inputs_alloc) > 0

# read weights from call args
for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
if is_const(arg):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
else:
if not inputs_as_list:
if move_args_to_kwargs:
# In this case we move all arguments and inputs to the kwargs
op_call_kwargs.update({tensor_inputs_alloc[i]: arg})

# remove weights and KerasTensors from op_call_args
if inputs_as_list:
op_call_args = tuple(op_call_args[1:])
else:
op_call_args = tuple([a for i, a in enumerate(op_call_args)
if not (i in weights or is_tensor(a) or (move_args_to_kwargs and tensor_inputs_alloc[i]
in op_call_kwargs))])

return op_call_args


def _has_const_attributes(op_call_args: List, op_call_kwargs: Dict, input_as_list: bool) -> bool:
"""
Returns whether the layer's input include a constant tensor (that we might want to quantize).
Args:
op_call_args: A list of arguments to the layer.
op_call_kwargs: A dictionary of key-arguments to the layer.
input_as_list: Whether the input to the layer is a list of tensors.
Returns: True if the input arguments include a constant tensor, False otherwise.
"""
if input_as_list:
return any([is_const(a) for a in op_call_args[0]])
const_args = [a for a in op_call_args if is_const(a)]
const_kwargs = [k for k, v in op_call_kwargs.items() if is_const(v)]

return len(const_args) > 0 or len(const_kwargs) > 0


def build_node(node: KerasNode,
Expand All @@ -110,8 +216,8 @@ def build_node(node: KerasNode,
"""
keras_layer = node.layer # get the layer the node represents.
layer_config = keras_layer.get_config() # layer configuration to reconstruct it.
op_call_args = node.call_args
op_call_kwargs = node.call_kwargs
op_call_args = copy(node.call_args)
op_call_kwargs = copy(node.call_kwargs)
layer_class = type(keras_layer) # class path to instantiating it in back2framework.
weights = {v.name: v.numpy() for v in keras_layer.weights} # layer's weights

Expand Down Expand Up @@ -152,32 +258,14 @@ def build_node(node: KerasNode,
if len(weights) > 0:
Logger.critical('Functional nodes are not expected to have weights in this framework.')

# read weights from call args
tf_function_symbols = get_tf_function_symbols()
for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
if is_const(arg) or (
keras_layer.symbol in tf_function_symbols and
isinstance(arg, (tuple, list))):
if inputs_as_list or i in kwarg2index.values():
weights.update({i: to_numpy(arg, is_single_tensor=True)})
# remove weights and KerasTensors and weights from op_call_args
if inputs_as_list:
op_call_args = tuple(op_call_args[1:])
else:
op_call_args = tuple([a for i, a in enumerate(op_call_args)
if not (i in weights or is_tensor(a))])

# read weights from call kwargs
weight_keys = []
for k, v in op_call_kwargs.items():
if is_const(v) or (keras_layer.symbol in tf_function_symbols and
isinstance(v, (tuple, list))):
if k in kwarg2index:
weights.update({kwarg2index[k]: to_numpy(v, is_single_tensor=True)})
weight_keys.append(k)
# remove weights and KerasTensors and weights from op_call_kwargs
op_call_kwargs = {k: v for k, v in op_call_kwargs.items()
if not (kwarg2index.get(k) in weights or is_tensor(v))}
# Build tensor_input_alloc required for the model builder. All inputs are received as a list in the builder,
# so tensor_input_alloc is used to allocate each input in the correct place in the node's args & kwargs.
tensor_input_alloc = None if not _has_const_attributes(op_call_args, op_call_kwargs, inputs_as_list) \
else _build_arguments_alloc(node, inputs_as_list, kwarg2index)

op_call_args = _extract_const_attrs_from_args(op_call_args, op_call_kwargs, inputs_as_list,
tensor_input_alloc, weights)
op_call_kwargs = _extract_const_attrs_from_kwargs(op_call_kwargs, kwarg2index, weights)

node = FunctionalNode(node_name,
layer_config,
Expand All @@ -190,7 +278,8 @@ def build_node(node: KerasNode,
is_reused,
reuse_group,
functional_op=keras_layer.function,
inputs_as_list=inputs_as_list)
inputs_as_list=inputs_as_list,
tensor_input_allocs=tensor_input_alloc)
else:
# Read constant weights from layers such as layers.Add
if len(op_call_args) > 0 and isinstance(op_call_args[0], (list, tuple)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import numpy as np

import model_compression_toolkit as mct
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from mct_quantizers import KerasQuantizationWrapper
from mct_quantizers import KerasQuantizationWrapper, QuantizationMethod

from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import numpy as np

import model_compression_toolkit as mct
from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
Expand Down Expand Up @@ -69,6 +72,31 @@ def create_networks(self):
x = self.layer(x, self.const)
return tf.keras.models.Model(inputs=inputs, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}, for operator {self.layer}')


class ConstRepresentationListTypeArgsTest(BaseKerasFeatureNetworkTest):

def __init__(self, unit_test, input_shape=(32, 32, 16)):
super(ConstRepresentationListTypeArgsTest, self).__init__(unit_test=unit_test, input_shape=input_shape)

def generate_inputs(self):
# need positive inputs so won't divide with zero or take root of negative number
return [1 + np.random.random(in_shape) for in_shape in self.get_input_shapes()]

def get_tpc(self):
return get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

def create_networks(self):
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
x = tf.image.resize(inputs, size=self.get_input_shapes()[0][1:3])
return tf.keras.models.Model(inputs=inputs, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
Expand Down
Loading

0 comments on commit 9d54fe9

Please sign in to comment.