From 92e8ab04d1b7c2d176f776f57457420ea46bed6b Mon Sep 17 00:00:00 2001 From: Oliver Maier Date: Fri, 8 Mar 2024 09:00:26 +0100 Subject: [PATCH 1/2] Update quantize_wrapper.py trainable weights is only available after calling super.build. This leads to errors when reading models from saved .keras files as number of parameters does not match. This can be solved by changing the order as proposed. In addition, the trainable weights of the layer are expected to be in front of the quantize wrapper weights --- .../python/core/quantization/keras/quantize_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index 32a6e2dec..6424ddf3c 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -252,8 +252,8 @@ def losses(self): class QuantizeWrapperV2(QuantizeWrapper): def build(self, input_shape): - self._trainable_weights.extend(self.layer.trainable_weights) super(QuantizeWrapperV2, self).build(input_shape) + self._trainable_weights = self.layer.trainable_weights + self._trainable_weights @property def trainable_weights(self): From 0ada0a06fcb5e092e34c45d49c324bca277b2ede Mon Sep 17 00:00:00 2001 From: Oliver Maier Date: Fri, 8 Mar 2024 09:06:34 +0100 Subject: [PATCH 2/2] Update quantize_layer.py Calling super.build() of the layer to indicate that the layer was built. Previous version did not set the self.built=True flag and would lead to a mismatch between stored and expected weights when loading a .keras model --- .../python/core/quantization/keras/quantize_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py index 59df68c1c..af71b85b5 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py @@ -54,6 +54,7 @@ def __init__(self, quantizer, **kwargs): self.quantizer = quantizer def build(self, input_shape): + super(QuantizeLayer, self).build(input_shape) if self.quantizer: self.quantizer_vars = self.quantizer.build( input_shape, self.name, self)