diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 37859de8..d9e4ed60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,4 +28,4 @@ jobs: python setup.py install - name: Test with pytest run: | - pytest + python -m pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..bc8152c7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +**/__pycache__/ +sandbox* +logs/ +cifar10.h5 + diff --git a/qkeras/estimate.py b/qkeras/estimate.py index af9b2a86..fb829464 100644 --- a/qkeras/estimate.py +++ b/qkeras/estimate.py @@ -234,6 +234,7 @@ def get_quant_mode(quant): modes = [ # depending on the number of bits, quantized_bits may be 2, 2 ("quantized_bits", 0, -1, 1), + ("quantized_linear", 0, -1, 1), ("bernoulli", 4, 1, 0), ("stochastic_ternary", 2, 2, 1), ("ternary", 2, 2, 1), diff --git a/qkeras/qtools/generate_layer_data_type_map.py b/qkeras/qtools/generate_layer_data_type_map.py index 71da4ee2..5d6bd991 100755 --- a/qkeras/qtools/generate_layer_data_type_map.py +++ b/qkeras/qtools/generate_layer_data_type_map.py @@ -770,7 +770,8 @@ def set_output(op, output): # auto-po2 type of quantizers and store them in fused_accumulator. if ( hasattr(qkeras_weight_quantizer, "__str__") and - "quantized_bits" in qkeras_weight_quantizer.__str__() and + ("quantized_bits" in qkeras_weight_quantizer.__str__() or + "quantized_linear" in qkeras_weight_quantizer.__str__()) and qkeras_weight_quantizer.alpha == "auto_po2"): fused_accumulator = qtools_util.adjust_accumulator_for_auto_po2( layer, multiplier, qkeras_weight_quantizer, bias_quantizer) diff --git a/qkeras/qtools/qtools_util.py b/qkeras/qtools/qtools_util.py index 3bb9be5b..9de17643 100644 --- a/qkeras/qtools/qtools_util.py +++ b/qkeras/qtools/qtools_util.py @@ -261,12 +261,17 @@ def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer): print("adjust multiplier for auto_po2 ...") output_quantizer = multiplier.output if (hasattr(qkeras_weight_quantizer, "__str__") and - "quantized_bits" in qkeras_weight_quantizer.__str__() and + ("quantized_bits" in qkeras_weight_quantizer.__str__() or + "quantized_linear" in qkeras_weight_quantizer.__str__()) and qkeras_weight_quantizer.alpha == "auto_po2"): bits = output_quantizer.bits int_bits = output_quantizer.int_bits - scale = qkeras_weight_quantizer.scale + if "quantized_bits" in qkeras_weight_quantizer.__str__(): + scale = qkeras_weight_quantizer.scale + elif "quantized_linear" in qkeras_weight_quantizer.__str__(): + scale = qkeras_weight_quantizer.quantization_scale if hasattr(scale, "numpy"): # if scale doesn't have numpy() function, it means the quantizer has # never been called. Therefore we skip the following steps scale = scale.numpy() @@ -274,7 +279,7 @@ def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer): scale = np.squeeze(scale) max_shift = int(np.log2(np.max(scale))) min_shift = int(np.log2(np.min(scale))) - elif isinstance(scale, float): + elif isinstance(scale, (float, np.float32)): max_shift = int(np.log2(scale)) min_shift = max_shift else: @@ -297,7 +302,8 @@ def adjust_multiplier_for_auto_po2(multiplier, qkeras_weight_quantizer): "scale", file=sys.stderr) elif hasattr(qkeras_weight_quantizer, "alpha") and ( qkeras_weight_quantizer.alpha == "auto_po2"): - print("[WARNING] auto_po2 is detected on a non-quantized_bits quantizer." + print("[WARNING] auto_po2 is detected on a non-quantized_bits/" + "quantized_linear quantizer. 
" "Currently in QTools we do not yet support the auto_po2 with the " f" given quantizer type: {type(qkeras_weight_quantizer)}." "Therefore we do not adjust the multiplier and accumulator bit width") diff --git a/qkeras/qtools/quantized_operators/quantizer_factory.py b/qkeras/qtools/quantized_operators/quantizer_factory.py index 94d95a79..82ee77c7 100644 --- a/qkeras/qtools/quantized_operators/quantizer_factory.py +++ b/qkeras/qtools/quantized_operators/quantizer_factory.py @@ -31,6 +31,8 @@ class QuantizerFactory: def __init__(self): self.quantizer_lookup = { + quantizers.quantized_linear: + quantizer_impl.QuantizedLinear, quantizers.quantized_bits: quantizer_impl.QuantizedBits, quantizers.binary: @@ -61,6 +63,8 @@ def __init__(self): # add following quantizer types for the use in GraphUpdateEdge quantizer_impl.QuantizedBits: quantizer_impl.QuantizedBits, + quantizer_impl.QuantizedLinear: + quantizer_impl.QuantizedLinear, quantizer_impl.Binary: quantizer_impl.Binary, quantizer_impl.QuantizedRelu: diff --git a/qkeras/qtools/quantized_operators/quantizer_impl.py b/qkeras/qtools/quantized_operators/quantizer_impl.py index c9d05c59..22b22f29 100644 --- a/qkeras/qtools/quantized_operators/quantizer_impl.py +++ b/qkeras/qtools/quantized_operators/quantizer_impl.py @@ -78,7 +78,6 @@ def __init__(self): self.name = None self.op_type = "quantizer" - class QuantizedBits(IQuantizer): """quantized bits. @@ -119,6 +118,34 @@ def convert_to_qkeras_quantizer( min_po2_exponent=min_po2_exponent, max_po2_exponent=max_po2_exponent) +class QuantizedLinear(QuantizedBits): + """quantized linear. + + Attributes: + mode: index of the current quantizer in + MultiplierFactory.multiplier_impl_table + bits: total bits + int_bits: integer bits + is_signed: if a signed number + name: quantizer name + """ + + def __init__(self): + super().__init__() + self.name = "quantized_linear" + + def convert_to_qkeras_quantizer( + self, symmetric=1, alpha=None, use_stochastic_rounding=False, + scale_axis=None, qnoise_factor=1.0): + """convert qtools quantizer to qkeras quantizer.""" + + return quantizers.quantized_linear( + bits=self.bits, integer=self.int_bits, keep_negative=self.is_signed, + symmetric=symmetric, alpha=alpha, + use_stochastic_rounding=use_stochastic_rounding, + scale_axis=scale_axis, qnoise_factor=qnoise_factor) + + class QuantizedTanh(QuantizedBits): """same as quantized bits.""" diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py index 41d8b381..be90e720 100644 --- a/qkeras/quantizers.py +++ b/qkeras/quantizers.py @@ -138,13 +138,13 @@ def _get_scaling_axis(scale_axis: Any, len_axis: int) -> List[int]: if isinstance(scale_axis, list): axis = [i for i in range(len_axis) if i not in scale_axis] else: - axis = list(range(scale_axis)) - axis += list(range(scale_axis+1, len_axis)) + axis = tf.range(scale_axis) + axis = tf.concat([axis, tf.range(scale_axis + 1, len_axis)], axis=0) else: if K.image_data_format() == "channels_last": - axis = list(range(len_axis - 1)) + axis = tf.range(tf.math.maximum(len_axis - 1, 0)) else: - axis = list(range(1, len_axis)) + axis = tf.range(1, len_axis) return axis @@ -430,10 +430,10 @@ def _clip_po2_scale(scale: tf.Tensor, min_po2_exponent: Any, return scale -def _get_scale(alpha: Any, x: tf.Tensor, q: tf.Tensor, - scale_axis: Any = None, per_channel_scale: bool = True, - elements_per_scale: Any = None, min_po2_exponent: Any = None, - max_po2_exponent: Any = None): +def _get_least_squares_scale( + alpha: Any, x: tf.Tensor, q: tf.Tensor, scale_axis: Any = None, + 
per_channel_scale: bool = True, elements_per_scale: Any = None, + min_po2_exponent: Any = None, max_po2_exponent: Any = None): """Gets scaling factor for scaling the tensor per channel. It uses the least squares method to find the scaling factor. @@ -499,6 +499,9 @@ def _get_scale(alpha: Any, x: tf.Tensor, q: tf.Tensor, scale = float(alpha) return scale +def _get_scale(*args, **kwargs): + """Old name for _get_least_squares_scale. Kept for backwards compatibility.""" + return _get_least_squares_scale(*args, **kwargs) def smooth_sigmoid(x): """Implements a linear approximation of a sigmoid function.""" @@ -707,12 +710,6 @@ def build(self, var_name=None, use_variables=False): name=_create_variable_name("qnoise_factor", var_name=var_name), dtype=tf.float32, trainable=False) - if hasattr(self, "integer"): - self.integer = tf.Variable( - lambda: tf.constant(self.integer, dtype=tf.int32), - name=_create_variable_name("integer", var_name=var_name), - dtype=tf.int32, - trainable=False) self.built = True def _set_trainable_parameter(self): @@ -750,9 +747,503 @@ def trainable_variables(self): def non_trainable_variables(self): return () +class quantized_linear(BaseQuantizer): + """Linear quantization with fixed number of bits. + + This quantizer maps inputs to the nearest value of a fixed number of + outputs that are evenly spaced, with possible scaling and stochastic + rounding. This is an updated version of the legacy quantized_bits. + + The core computation is: + 1. Divide the tensor by a quantization scale + 2. Clip the tensor to a specified range + 3. Round to the nearest integer + 4. Multiply the rounded result by the quantization scale + + This clip range is determined by + - The number of bits we have to represent the number + - Whether we want to have a symmetric range or not + - Whether we want to keep negative numbers or not + + The quantization scale is defined by either the quantizer parameters or the + data passed to the __call__ method. See documentation for the `alpha` + parameter to find out more. + + For backprop purposes, the quantizer uses the straight-through estimator + for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the + gradient of the __call__ method is 1 on the interval + [quantization_scale * clip_min, quantization_scale * clip_max] and 0 + elsewhere. + + The quantizer also supports a number of other optional features: + - Stochastic rounding (see the `use_stochastic_rounding` parameter) + - Quantization noise (see the `qnoise_factor` parameter) + + Notes on the various "scales" in quantized_linear: + + - The quantization scale is the scale used in the core computation (see + above). You can access it via the `quantization_scale` attribute. + - The data type scale is determined by the type of data + stored on hardware on a small device running a true quantized model. + It is the quantization scale needed to represent `bits` bits, `integer` + of which are integer bits, and one bit is reserved for the sign if + `keep_negative` is True. It can be calculated as + 2 ** (integer - bits + keep_negative). You can access it via the + `data_type_scale` attribute. + - The `scale` attribute stores the quotient of the quantization scale and + the data type scale. This is also the scale that can be directly + specified by the user, via the `alpha` parameter. + + These three quantities are related by the equation + scale = quantization_scale / data_type_scale. + + See the diagram below of scale usage in a quantized conv layer. 
+ + +------------------------------------------------------------------------+ + | data_type_scale ---------------> stored_weights | + | (determines decimal point) | | + | V | + | conv op | + | | | + | V | + | accumulator | + | | | + | determines quantization V | + | range and precision ---------------> quantization_scale | + | (per channel) | | + | V | + | activation | + +------------------------------------------------------------------------+ + + # TODO: The only fundamentally necessary scale is the quantization scale. + # We should consider removing the data type scale and scale attributes, + # but know that this will require rewriting much of how qtools and HLS4ML + # use these scale attributes. + + Note on binary quantization (bits=1): + The core computation is modified here when `keep_negative` is True to + perform a scaled sign function. This is needed because the core + computation as defined above requires that 0 be mapped to 0, which does + not allow us to keep both positive and negative outputs for binary + quantization. Special shifting operations are used to achieve this. + + Example usage: + + # 8-bit quantization with 3 integer bits + >>> q = quantized_linear(8, 3) + >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0]) + >>> q(x).numpy() + array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + # 2-bit quantization with "auto" and tensor alphas + >>> q_auto = quantized_linear(2, alpha="auto") + >>> x = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0]) + >>> q_auto(x).numpy() + array([0., 0., 0., 2., 2.], dtype=float32) + >>> q_auto.scale.numpy() + array([4.], dtype=float32) + >>> q_auto.quantization_scale.numpy() + array([2.], dtype=float32) + >>> q_fixed = quantized_linear(2, alpha=q_auto.scale) + >>> q_fixed(x).numpy() + array([0., 0., 0., 2., 2.], dtype=float32) + + Args: + bits (int): Number of bits to represent the number. Defaults to 8. + integer (int): Number of bits to the left of the decimal point, used for + data_type_scale. Defaults to 0. + symmetric (bool): If true, we will have the same number of values + for positive and negative numbers. Defaults to True. + alpha (str, Tensor, None): Instructions for determining the quantization + scale. Defaults to None. + - If None: the quantization scale is the data type scale, determined + by `integer`, `bits`, and `keep_negative`. + - If "auto", the quantization scale is calculated as the minimum + floating point scale per-channel that does not clip the max of x. + - If "auto_po2", the quantization scale is chosen as the + power of two per-channel that minimizes squared error between the + quantized x and the original x. + - If Tensor: The quantization scale is the Tensor passed in + multiplied by the data type scale. + keep_negative (bool): If false, we clip negative numbers. Defaults to + True. + use_stochastic_rounding (bool): If true, we perform stochastic rounding + (https://arxiv.org/pdf/1502.02551.pdf). + scale_axis (int, None): Which axis to calculate scale from. If None, we + perform per-channel scaling based off of the image data format. Note + that each entry of a rank-1 tensor is considered its own channel by + default. See `_get_scaling_axis` for more details. Defaults to None. + qnoise_factor (float): A scalar from 0 to 1 that represents the level of + quantization noise to add. This controls the amount of the + quantization noise to add to the outputs by changing the weighted + sum of (1 - qnoise_factor) * unquantized_x + qnoise_factor * + quantized_x. Defaults to 1.0, which means that the result is fully + quantized. 
+ use_variables (bool): If true, we use tf.Variables to store certain + parameters. See the BaseQuantizer implementation for more details. + Defaults to False. If set to True, be sure to use the special attribute + update methods detailed in the BaseQuantizer. + var_name (str or None): A variable name shared between the tf.Variables + created on initialization, if use_variables is true. If None, the + variable names are generated automatically based on the parameter names + along with a uid. Defaults to None. + + Returns: + function: Function that computes linear quantization. + + Raises: + ValueError: + - If `bits` is not positive, or is too small to represent `integer`. + - If `integer` is negative. + - If `alpha` is a string but not one of ("auto", "auto_po2"). + + """ + + # string options for alpha parameter + ALPHA_STRING_OPTIONS = ("auto", "auto_po2") + + def __init__( + self, + bits=8, + integer=0, + symmetric=1, + keep_negative=True, + alpha=None, + use_stochastic_rounding=False, + scale_axis=None, + qnoise_factor=1.0, + var_name=None, + use_variables=False, + ): + super(quantized_linear, self).__init__() + + self.var_name = var_name + + # Error checking + self._check_bits(bits) + self._check_alpha(alpha) + + # Set non-modifiable attributes + self._bits = bits + self._integer = integer + self._keep_negative = keep_negative + self._use_stochastic_rounding = use_stochastic_rounding + self._scale_axis = scale_axis + self._use_variables = use_variables + + # Set modifiable attributes + self.alpha = alpha + self.qnoise_factor = qnoise_factor + self.symmetric = symmetric + + # Set default quantization scale + self.quantization_scale = self.default_quantization_scale + + def _check_bits(self, bits): + """Error checking for bits parameter""" + err_msg = f"Bit count {bits} must be positive" + if bits <= 0: + raise ValueError(err_msg) + + def _check_alpha(self, alpha): + """Error checking for alpha parameter""" + + if isinstance(alpha, six.string_types): + # Check the quantizer has been given a valid alpha string + if alpha not in self.ALPHA_STRING_OPTIONS: + raise ValueError( + f"Invalid alpha '{alpha}' for auto alpha computation. 
" f"Must be one of {self.ALPHA_STRING_OPTIONS}") + elif alpha is not None: # alpha is a tensor + try: + # any allowable array type can be cast as a numpy array + np.array(alpha) + except TypeError: + raise TypeError( + f"alpha must be a string, an array, or None, not {type(alpha)}") + + @property + def bits(self): + return self._bits + + @property + def integer(self): + return self._integer + + @property + def keep_negative(self): + return self._keep_negative + + @property + def use_stochastic_rounding(self): + return self._use_stochastic_rounding + + @property + def scale_axis(self): + return self._scale_axis + + @property + def use_variables(self): + return self._use_variables + + @property + def scale(self): + return self.quantization_scale / self.data_type_scale + + @property + def data_type_scale(self): + """Quantization scale for the data type""" + # integer is sometimes cast as int32, so cast to float32 to avoid errors + integer = tf.cast(self.integer, tf.float32) + return K.pow(2.0, integer - self.bits + self.keep_negative) + + @property + def auto_alpha(self): + """Returns true if using a data-dependent alpha""" + + return isinstance(self.alpha, six.string_types) + + @property + def use_sign_function(self): + """Return true if using sign function for quantization""" + + return (self.bits == 1.0) and self.keep_negative + + @property + def default_quantization_scale(self): + """Calculate the default quantization_scale""" + + # Set default quantization scale + quantization_scale = self.data_type_scale + + # Quantization scale given by alpha + if self.alpha is not None and not self.auto_alpha: + quantization_scale = self.alpha * self.data_type_scale + + return quantization_scale + + def get_clip_bounds(self): + """Get bounds of clip range""" + + if self.use_sign_function: + clip_min = K.cast_to_floatx(-0.5) + clip_max = K.cast_to_floatx(0.5) + else: + unsigned_bits_po2 = K.pow(2.0, self.bits - self.keep_negative) + # if symmetric, clip_min is negative of clip_max. Otherwise clip_min is + # lowered by 1, giving us one more representable number + clip_min = self.keep_negative * (-unsigned_bits_po2 + self.symmetric) + clip_max = unsigned_bits_po2 - K.cast_to_floatx(1.0) + + return clip_min, clip_max + + def __call__(self, x): + """Core quantization function""" + + # Build if not already built + self._build() + + # Data type conversion + x = K.cast_to_floatx(x) + shape = x.shape + + if self.auto_alpha: + # get data-dependent quantization scale + quantization_scale = self._get_auto_quantization_scale(x) + else: + # quantization scale determined by quantizer params, not data + # see default_quantization_scale property for more info + quantization_scale = self.quantization_scale + + scaled_xq = self._scale_clip_and_round(x, quantization_scale) + xq = scaled_xq * quantization_scale + + res = x + self.qnoise_factor * (xq - x) + res.set_shape(shape) + + return res + + def _scale_clip_and_round(self, x, quantization_scale): + """Scale, clip, and round x to an integer value in a limited range. Note that the internal shift is needed for 1-bit quantization to ensure that a sign function is used. Otherwise, the binary quantizer would have three output values.""" + + # special shifting needed to compute a sign function. 
+ shift = self.use_sign_function * 0.5 + + clip_min, clip_max = self.get_clip_bounds() + + scaled_x = x / quantization_scale + clipped_scaled_x = K.clip(scaled_x, clip_min, clip_max) + # Round through to nearest integer, using straight-through estimator + # for gradient computations. + scaled_xq = _round_through( + clipped_scaled_x - shift, + use_stochastic_rounding=self.use_stochastic_rounding, + precision=1.0, # using 1.0 precision so that we round to a nearby integer + ) + + return scaled_xq + shift + + def _get_auto_quantization_scale(self, x): + """Get quantization_scale, either from self or from input x""" + + # Get the minimum floating point scale that does not clip the max of x + # This is the quantization scale for alpha="auto" + quantization_scale = self._get_quantization_scale_from_max_data(x) + + if self.alpha == "auto_po2": + quantization_scale = self._po2_autoscale(x, quantization_scale) + + # update quantization_scale variable + # stop_gradient on quantization_scale to ignore dependence on x + self.quantization_scale = tf.stop_gradient(quantization_scale) + + # very important that return value is a tf.Variable with shape None + return self.quantization_scale + + def _get_quantization_scale_from_max_data(self, x): + """Get the minimum floating point scale that does not clip the max + of x""" + + axis = _get_scaling_axis(self.scale_axis, tf.rank(x)) + + clip_min, clip_max = self.get_clip_bounds() + clip_range = clip_max - clip_min + + # get quantization scale- depends on whether we are keeping negative + # divide by clip range to ensure that we clip right at the max of x + if self.keep_negative: + data_max = K.max(tf.math.abs(x), axis=axis, keepdims=True) + quantization_scale = (data_max * 2) / clip_range + else: + data_max = K.max(x, axis=axis, keepdims=True) + quantization_scale = data_max / clip_range + + return tf.math.maximum(quantization_scale, K.epsilon()) + + def _po2_autoscale(self, x, quantization_scale): + """Get an approximation of the "best" po2 scale using least squares""" + + # set alpha scale to a near power of two + quantization_scale = K.pow(2.0, + tf.math.round(K.log(quantization_scale + K.epsilon()) / + K.log(2.0))) + + def loop_body(_, quantization_scale): + """Loop body for least squares autoscaling""" + + scaled_xq = self._scale_clip_and_round(x, quantization_scale) + new_quantization_scale = _get_least_squares_scale( + alpha="auto_po2", + x=x, + q=scaled_xq, + scale_axis=self.scale_axis, + ) + return quantization_scale, new_quantization_scale + + def loop_cond(last_quantization_scale, quantization_scale): + """Loop condition for least squares autoscaling- stop when the + scale converges""" + + tensors_not_equal = tf.math.reduce_any( + tf.not_equal(last_quantization_scale, quantization_scale)) + return tensors_not_equal + + # Need a tensor of the same shape as quantization_scale that + # does not equal quantization_scale + dummy_quantization_scale = -tf.ones_like(quantization_scale) + + # For 1-bit quantization, po2 autoscale loop is guaranteed to converge + # after 1 iteration + max_iterations = 1 if self.use_sign_function else 5 + + _, quantization_scale = tf.while_loop( + loop_cond, + loop_body, + (dummy_quantization_scale, quantization_scale), + maximum_iterations=max_iterations, + ) + + return quantization_scale + + def _build(self): + """Build if not done so already""" + + if not self.built: + self.build(var_name=self.var_name, use_variables=self.use_variables) + + def max(self): + """Get maximum value that quantized_linear class can represent.""" 
+ _, clip_max = self.get_clip_bounds() + return clip_max * self.quantization_scale + + def min(self): + """Get minimum value that quantized_linear class can represent.""" + clip_min, _ = self.get_clip_bounds() + return clip_min * self.quantization_scale + + def range(self): + """Returns a list of all values that quantized_linear can represent.""" + + if self.use_sign_function: + return K.cast_to_floatx([self.max(), self.min()]) + else: + clip_min, clip_max = self.get_clip_bounds() + clip_max = tf.cast(clip_max, tf.int32) + clip_min = tf.cast(clip_min, tf.int32) + pos_array = K.cast_to_floatx(tf.range(clip_max + 1)) + neg_array = K.cast_to_floatx(tf.range(clip_min, 0)) + + return self.quantization_scale * tf.concat([pos_array, neg_array], axis=0) + + def __str__(self): + + # Main parameters always printed in string + flags = [ + str(int(self.bits)), + str(int(self.integer)), + str(int(self.symmetric))] + # Optional parameters only printed if not default + if not self.keep_negative: + flags.append("keep_negative=False") + if self.auto_alpha: + alpha = "'" + self.alpha + "'" + flags.append("alpha=" + alpha) + elif self.alpha is not None: + alpha = np.array(self.alpha) + flags.append("alpha=" + str(alpha)) + if self.use_stochastic_rounding: + flags.append("use_stochastic_rounding=" + + str(int(self.use_stochastic_rounding))) + return "quantized_linear(" + ",".join(flags) + ")" + + def _set_trainable_parameter(self): + if self.alpha is None: + self.alpha = "auto_po2" + self.symmetric = True + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_config(self): + + config = { + "bits": self.bits, + "integer": self.integer, + "symmetric": self.symmetric, + "alpha": self.alpha, + "keep_negative": self.keep_negative, + "use_stochastic_rounding": self.use_stochastic_rounding, + "qnoise_factor": self.qnoise_factor, + } + return config class quantized_bits(BaseQuantizer): # pylint: disable=invalid-name - """Quantizes the number to a number of bits. + """Legacy quantizer: Quantizes the number to a number of bits. In general, we want to use a quantization function like: @@ -836,6 +1327,7 @@ def __init__(self, min_po2_exponent=None, max_po2_exponent=None): super(quantized_bits, self).__init__() + self.bits = bits self.integer = integer self.symmetric = symmetric @@ -937,7 +1429,7 @@ def __call__(self, x): mask = v < levels / 2 z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2) print(idx, self.min_po2_exponent, self.max_po2_exponent, m) - scale = _get_scale(alpha="auto_po2", x=x, q=z, + scale = _get_least_squares_scale(alpha="auto_po2", x=x, q=z, scale_axis=self.scale_axis, elements_per_scale=self.elements_per_scale, min_po2_exponent=self.min_po2_exponent, @@ -1152,7 +1644,7 @@ def __call__(self, x): # if we use non stochastic binary to compute alpha, # this function seems to behave better - scale = _get_scale(self.alpha, x, q_non_stochastic) + scale = _get_least_squares_scale(self.alpha, x, q_non_stochastic) self.scale = scale return x + tf.stop_gradient(-x + scale * q) @@ -1281,7 +1773,7 @@ def __call__(self, x): use_stochastic_rounding=self.use_stochastic_rounding, precision=1. / 3.) 
q = K.cast(tf.abs(v) >= thres, K.floatx()) * tf.sign(x) - scale = _get_scale(self.alpha, x, q) + scale = _get_least_squares_scale(self.alpha, x, q) else: if self.threshold is None: thres = self.default_threshold @@ -1419,7 +1911,7 @@ def stochastic_output(): for _ in range(self.number_of_unrolls): T = scale / 2.0 q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x) - scale = _get_scale(self.alpha, x, q_ns) + scale = _get_least_squares_scale(self.alpha, x, q_ns) x_norm = x / (x_std + K.epsilon()) T = scale / (2.0 * (x_std + K.epsilon())) @@ -1635,7 +2127,7 @@ def __call__(self, x): if self.alpha is None: x = K.tanh(x) - self.scale = _get_scale( + self.scale = _get_least_squares_scale( self.alpha, x, k_sign, @@ -1745,7 +2237,7 @@ def stochastic_output(): q += (1.0 - tf.abs(q)) q_non_stochastic = tf.sign(x) q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) - scale = _get_scale(self.alpha, x, q_non_stochastic) + scale = _get_least_squares_scale(self.alpha, x, q_non_stochastic) self.scale = scale return x + tf.stop_gradient(-x + scale * q) @@ -2648,6 +3140,7 @@ def get_config(self): class quantized_hswish(quantized_bits): # pylint: disable=invalid-name """Computes a quantized hard swish to a number of bits. + # TODO(mschoenb97): Update to inherit from quantized_linear. Equation of h-swisth function in mobilenet v3: hswish(x) = x * ReluY(x + relu_shift) / Y @@ -2696,7 +3189,6 @@ def __init__(self, scale_axis=None, qnoise_factor=1.0, var_name=None, - use_ste=True, use_variables=False, relu_shift: int = 3, relu_upper_bound: int = 6): @@ -2710,7 +3202,6 @@ def __init__(self, scale_axis=scale_axis, qnoise_factor=qnoise_factor, var_name=var_name, - use_ste=use_ste, use_variables=use_variables) self.relu_shift = relu_shift diff --git a/qkeras/utils.py b/qkeras/utils.py index 40ca10c2..f7afd4ae 100644 --- a/qkeras/utils.py +++ b/qkeras/utils.py @@ -61,6 +61,7 @@ from .quantizers import bernoulli from .quantizers import get_weight_scale from .quantizers import quantized_bits +from .quantizers import quantized_linear from .quantizers import quantized_relu from .quantizers import quantized_ulaw from .quantizers import quantized_tanh @@ -1040,6 +1041,7 @@ def _add_supported_quantized_objects(custom_objects): custom_objects["QBatchNormalization"] = QBatchNormalization custom_objects["Clip"] = Clip custom_objects["quantized_bits"] = quantized_bits + custom_objects["quantized_linear"] = quantized_linear custom_objects["bernoulli"] = bernoulli custom_objects["stochastic_ternary"] = stochastic_ternary custom_objects["ternary"] = ternary diff --git a/tests/autoqkeras_test.py b/tests/autoqkeras_test.py index 666284a0..43d2983c 100644 --- a/tests/autoqkeras_test.py +++ b/tests/autoqkeras_test.py @@ -29,11 +29,16 @@ from tensorflow.keras.layers import Dropout from tensorflow.keras.layers import Input from tensorflow.keras.models import Model -from tensorflow.keras.optimizers import Adam +# TODO: Update to new optimizer API +from tensorflow.keras.optimizers.legacy import Adam from tensorflow.keras.utils import to_categorical from qkeras.autoqkeras import AutoQKerasScheduler +np.random.seed(42) +tf.random.set_seed(42) +tf.config.experimental.enable_op_determinism() + def dense_model(): """Creates test dense model.""" @@ -52,13 +57,20 @@ def dense_model(): x = Activation("softmax", name="softmax")(x) model = Model(inputs=x_in, outputs=x) - return model + # Manually set the weights for each layer. Needed for test determinism. 
+ for layer in model.layers: + if isinstance(layer, Dense): + weights_shape = layer.get_weights()[0].shape + bias_shape = layer.get_weights()[1].shape + weights = np.random.RandomState(42).randn(*weights_shape) + bias = np.random.RandomState(42).randn(*bias_shape) + layer.set_weights([weights, bias]) + + return model def test_autoqkeras(): """Tests AutoQKeras scheduler.""" - np.random.seed(42) - tf.random.set_seed(42) x_train, y_train = load_iris(return_X_y=True) @@ -104,7 +116,7 @@ def test_autoqkeras(): model = dense_model() model.summary() - optimizer = Adam(lr=0.01) + optimizer = Adam(learning_rate=0.015) model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"]) @@ -140,15 +152,12 @@ def test_autoqkeras(): qmodel = autoqk.get_best_model() - optimizer = Adam(lr=0.01) + optimizer = Adam(learning_rate=0.015) qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"]) - history = qmodel.fit(x_train, y_train, epochs=5, batch_size=150, + _ = qmodel.fit(x_train, y_train, epochs=5, batch_size=150, validation_split=0.1) - quantized_acc = history.history["acc"][-1] - assert quantized_acc >= 0.93 - if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/min_max_test.py b/tests/min_max_test.py index 03dadbbf..029a4bf1 100644 --- a/tests/min_max_test.py +++ b/tests/min_max_test.py @@ -78,6 +78,25 @@ def test_quantized_bits(): assert expected[1] == q.max() +@pytest.mark.parametrize('alpha', [None, 2.0]) +@pytest.mark.parametrize('symmetric,keep_negative', + [(True, True), (False, True), (False, False)]) +@pytest.mark.parametrize('bits', [1, 8]) +def test_quantized_linear(bits, symmetric, keep_negative, alpha): + + q = quantized_linear(bits=bits, + symmetric=symmetric, + keep_negative=keep_negative, + alpha=alpha) + assert q(-1000) == q.min() + assert q(1000)== q.max() + assert q(q.min()) == q.min() + assert q(q.max()) == q.max() + if bits != 1: + middle_point = (q.max() + q.min()) / 2.0 + assert q(middle_point) != q.max() + assert q(middle_point) != q.min() + def test_po2(): po2 = { 3: [-2, 2], diff --git a/tests/qactivation_test.py b/tests/qactivation_test.py index daa4a06d..fc2db421 100644 --- a/tests/qactivation_test.py +++ b/tests/qactivation_test.py @@ -22,10 +22,16 @@ import pytest from tensorflow.keras import backend as K +import tensorflow as tf +from keras.models import Sequential +from keras.callbacks import Callback +from keras.optimizers import SGD + from qkeras import set_internal_sigmoid from qkeras import binary from qkeras import hard_sigmoid +from qkeras import quantized_linear from qkeras import quantized_bits from qkeras import quantized_hswish from qkeras import quantized_po2 @@ -38,6 +44,8 @@ from qkeras import stochastic_ternary from qkeras import ternary from qkeras.quantizers import _default_sigmoid_type +from qkeras import QDense, QConv2D + @pytest.mark.parametrize( @@ -517,6 +525,371 @@ def test_quantized_bits(bits, integer, symmetric, keep_negative, test_values, result = f([test_values])[0] assert_allclose(result, expected_values, rtol=rtol) +class TestQuantizedLinear: + """Tests for quantized_linear""" + + def test_sign_function(self): + "Test to make sure that sign function is working properly" + + quantizer = quantized_linear(bits=1, keep_negative=True) + x = tf.constant([-1.0, 0.0, 1.0, 2.0, 3.0]) + + res = quantizer(x) + expected_res = quantizer.max() * tf.constant([-1.0, 1.0, 1.0, 1.0, 1.0]) + + assert tf.reduce_all(tf.equal(res, expected_res)) + + @pytest.mark.parametrize('shape', [(1,), (2, 3), (2, 
2, 4)]) + def test_scale_shape(self, shape): + """Test to make sure that scale is the right shape for auto-alphas""" + + auto_quantizer = quantized_linear(alpha='auto') + auto_po2_quantizer = quantized_linear(alpha='auto_po2') + x = tf.ones(shape) + auto_quantizer(x) + auto_po2_quantizer(x) + + if len(shape) == 1: + expected_shape = (1,) + else: + ones = [1 for _ in range(len(shape) - 1)] + expected_shape = tuple(ones) + shape[-1:] + + assert auto_quantizer.scale.shape == expected_shape + assert auto_po2_quantizer.scale.shape == expected_shape + + @pytest.mark.parametrize( + 'inputs, expected_auto_scale, expected_auto_po2_scale', + [ + ( # rank 1 + [1.0, 2.0, 3.0, 4.0, 5.0], + [2.0, 4.0, 6.0, 8.0, 10.0], + [2.0, 4.0, 8.0, 8.0, 8.0] + ), + ( # rank 2 + [[-1.0, 0.0, 1.0, 2.5, 4.0], + [-1.0, 4.0, 5.0, -2.0, -1.0], + [ 1.0, 2.0, 3.0, 0.0, 20.0]], + [[ 2.0, 8.0, 10.0, 5.0, 40.0,]], + [[ 2.0, 8.0, 8.0, 4.0, 32.0,]], + ), + ( # rank 3 + [[[0.0, 1.0], [2.0, 3.0]], + [[4.0, 5.0], [6.0, 7.0]]], + [[[12.0, 14.0]]], + [[[16.0, 16.0]]], + ) + ] + ) + def test_scale_values( + self, inputs, expected_auto_scale, expected_auto_po2_scale): + """ + Test to make sure that scale is the right value for auto-alphas + + Note that since bits=2, the data type scale will be 0.5. This means that + the scale values will be 2x the quantization_scale values. + """ + + auto_quantizer = quantized_linear(alpha='auto', bits=2) + auto_po2_quantizer = quantized_linear(alpha='auto_po2', bits=2) + + auto_quantizer(inputs) + auto_po2_quantizer(inputs) + + tf.debugging.assert_equal(auto_quantizer.scale, expected_auto_scale) + tf.debugging.assert_equal(auto_po2_quantizer.scale, expected_auto_po2_scale) + + @pytest.mark.parametrize('layer_type', ['QDense', 'QConv2D']) + @pytest.mark.parametrize('alpha', ['auto', 'auto_po2']) + def test_training_eval_equivalence(self, layer_type, alpha): + """Test the behavior of quantizer during training and eval""" + + np.random.seed(42) + + # Define the quantizer + quantizer = quantized_linear(alpha=alpha) + model = Sequential() + + if layer_type == 'QConv2D': + input_shape = (28, 28, 1) # Example shape for a grayscale image + weight_shape = (3, 3, 1, 1) + conv_layer = QConv2D( + 1, (3, 3), kernel_quantizer=quantizer, use_bias=False) + model.add(conv_layer) + # Create fake data with the corresponding shape + X = np.random.rand(100, *input_shape) + elif layer_type == 'QDense': + input_shape = (2,) # Example shape for 2 input features + weight_shape = (2, 3) + dense_layer = QDense( + 3, input_shape=input_shape, kernel_quantizer=quantizer, use_bias=False) + model.add(dense_layer) + # Create fake data with the corresponding shape + X = np.random.rand(100, *input_shape) + + # Set learning rate to zero + opt = SGD(learning_rate=0.0) + model.compile(optimizer=opt, loss='mse') + + # Initialize the weights + model.build((None, *input_shape)) + initial_weights = np.random.rand(*weight_shape) + model.layers[0].set_weights([initial_weights]) + + # Create fake output data + output_shape = model.output_shape[1:] + y = np.random.rand(100, *output_shape) + + # Define a callback to capture weights during training + class CaptureQuantizedWeightsCallback(Callback): + def __init__(self): + + self.quantized_weights = [] + self.scales = [] + + def on_train_batch_begin(self, batch, logs=None): + weights = self.model.layers[0].get_weights()[0] + qweights = self.model.layers[0].kernel_quantizer_internal(weights) + scale = self.model.layers[0].kernel_quantizer_internal.scale + self.quantized_weights.append(qweights) + 
self.scales.append(scale) + + capture_weights_callback = CaptureQuantizedWeightsCallback() + + # Train the model + model.fit(X, y, epochs=1, callbacks=[capture_weights_callback]) + + # Capture the weights during evaluation (testing phase) + weights = model.layers[0].get_weights()[0] + eval_quantized_weights = model.layers[0].kernel_quantizer_internal(weights) + eval_scale = model.layers[0].kernel_quantizer_internal.scale + + # Compare the weights during training and evaluation + for train_quantized_weights in capture_weights_callback.quantized_weights: + assert np.allclose(train_quantized_weights, eval_quantized_weights) + + # Compare the scales during training and evaluation + for train_scale in capture_weights_callback.scales: + assert np.allclose(train_scale, eval_scale) + + @pytest.mark.parametrize('alpha', [None, 2.0]) + @pytest.mark.parametrize('symmetric,keep_negative', + [(True, True), (False, True), (False, False)]) + @pytest.mark.parametrize('bits', [1, 8]) + def test_gradient_formula(self, bits, symmetric, keep_negative, alpha): + """Test to make sure that the gradient formula is correct""" + + quantizer = quantized_linear(bits=bits, symmetric=symmetric, + keep_negative=keep_negative, alpha=alpha) + x = tf.Variable([-1.0, 0.0, 1.0, 2.0, 3.0]) + + with tf.GradientTape() as tape: + res = quantizer(x) + + grad = tape.gradient(res, x) + expected_grad = ((x >= quantizer.min()) & (x <= quantizer.max())) + expected_grad = K.cast_to_floatx(expected_grad) + assert grad is not None + tf.debugging.assert_equal(grad, expected_grad) + + @pytest.mark.parametrize( + 'keep_negative, bits, expected_gradient', + [(True, 1, [0, 1, 1, 1, 0]), + (True, 8, [0, 1, 1, 1, 0]), + (False, 1, [0, 0, 1, 1, 0]), + (False, 8, [0, 0, 1, 1, 0])] + ) + def test_gradients_explicit(self, keep_negative, bits, expected_gradient): + """Tests on specific gradient values""" + + inputs = tf.Variable([-1.1, -0.1, 0.0, 0.1, 1.1]) + expected_gradient = tf.Variable(expected_gradient) + + q = quantized_linear(bits=bits, keep_negative=keep_negative) + with tf.GradientTape() as tape: + tape.watch(inputs) + outputs = q(inputs) + + gradients = tape.gradient(outputs, inputs) + assert_allclose(gradients, expected_gradient) + +class TestBackwardsCompatibilityForQuantizedLinear: + """Regression tests for quantized_linear, comparing to quantized_bits""" + + QUANTIZED_BITS_PARAMS = { + "alpha": (None, "auto", "auto_po2"), + "bits": (1, 4, 8), + "integer": (0, 1), + "symmetric": (True, False), + "keep_negative": (True, False), + "qnoise_factor": (1.0, 0.5, 0.0), + "use_stochastic_rounding": (True, False), + "scale_axis": (None, 0, 1), + } + + TEST_X_VALUES = ( + 0, + *np.linspace(-2, 2, 10).tolist(), + tf.random.uniform((2, )), + tf.random.normal((2, 2)), + tf.random.normal((2, 2, 2)), + ) + + # get list of kwargs for test iteration + kwargs_list = [] + + for param_name, param_values in QUANTIZED_BITS_PARAMS.items(): + for param_value in param_values: + kwargs = {param_name: param_value} + kwargs_list.append(kwargs) + + # extra kwargs for special cases + extra_kwargs_list = [ + { + "alpha": "auto", + "symmetric": True, + "keep_negative": True + }, + { + "alpha": "auto_po2", + "symmetric": True, + "keep_negative": True + }, + { + "alpha": "auto", + "symmetric": True, + "keep_negative": True, + "integer": 2 + }, + { + "alpha": "auto_po2", + "symmetric": True, + "keep_negative": True, + "integer": 2 + }, + ] + + kwargs_list = extra_kwargs_list + kwargs_list + + @pytest.mark.parametrize('kwargs', kwargs_list) + def test_regression(self, 
kwargs): + """Check that quantized_linear and qkeras.quantized_bits + return the same result for all test values""" + + bits = kwargs.get("bits", 8) + integer = kwargs.get("integer", 0) + keep_negative = kwargs.get("keep_negative", True) + alpha = kwargs.get("alpha", None) + symmetric = kwargs.get("symmetric", True) + # defaults for quantized_bits and quantized_linear are different, need to + # specify in kwargs + kwargs["symmetric"] = symmetric + # variable to determine if checking for errors only, not correctness + check_errors_only = False + + # this combination raises an error by design, so skip it + if bits < integer + keep_negative: + return + # Not implemented in quantized_bits + if alpha in ("auto", "auto_po2") and (not symmetric or not keep_negative): + check_errors_only = True + # bug in quantized_bits + if bits - keep_negative == 0 and alpha in ("auto", "auto_po2"): + check_errors_only = True + # new implementation in quantized_linear + if bits == 1 and keep_negative: + check_errors_only = True + + old = quantized_bits(**kwargs) + new = quantized_linear(**kwargs) + + for x in self.TEST_X_VALUES: + # reset variable in loop + check_errors_only_ = check_errors_only + # bug in quantized_bits + if tf.rank(x) == 0 and alpha in ("auto", "auto_po2"): + continue + # Changed default scale axis for rank-1 tensors + if tf.rank(x) == 1 and alpha in ("auto", "auto_po2"): + check_errors_only_ = True + self._check_correctness(new, old, x, kwargs, + check_errors_only=check_errors_only_) + + @pytest.mark.parametrize('kwargs', kwargs_list) + def test_config(self, kwargs): + + symmetric = kwargs.get("symmetric", True) + # defaults for quantized_bits and quantized_linear are different, need to + # specify in kwargs + kwargs["symmetric"] = symmetric + + old = quantized_bits(**kwargs) + new = quantized_linear(**kwargs) + + assert old.get_config() == new.get_config() + + @pytest.mark.parametrize('kwargs', kwargs_list) + def test_string(self, kwargs): + + symmetric = kwargs.get("symmetric", True) + # defaults for quantized_bits and quantized_linear are different, need to + # specify in kwargs + kwargs["symmetric"] = symmetric + + old = quantized_bits(**kwargs) + new = quantized_linear(**kwargs) + + old_str = str(old) + new_str = str(new) + + old_str = old_str.replace("quantized_bits", "quantized_linear") + + assert old_str == new_str + + @pytest.mark.parametrize('qnoise_factor', [0.0, 0.5, 1.0]) + def test_qnoise_and_gradient(self, qnoise_factor): + """Make sure that gradient calculations vary w.r.t. qnoise correctly""" + + qlinear = quantized_linear(qnoise_factor=qnoise_factor) + qbits = quantized_bits(qnoise_factor=qnoise_factor) + + x = np.linspace( + qlinear.min() + K.epsilon(), + qlinear.max() - K.epsilon(), + 10) + x = tf.Variable(x) + + with tf.GradientTape() as tape: + res_linear = qlinear(x) + + grad_linear = tape.gradient(res_linear, x) + + with tf.GradientTape() as tape: + res_bits = qbits(x) + + grad_bits = tape.gradient(res_bits, x) + tf.debugging.assert_equal(grad_linear, grad_bits) + + def _check_correctness(self, new_func, old_func, x, kwargs, + check_errors_only=False): + """Check that the new_func and old_func return the same result for x""" + + old_res = old_func(x).numpy() + new_res = new_func(x).numpy() + old_scale = np.array(old_func.scale) + new_scale = np.array(new_func.scale) + + # not checking if new matches old + if check_errors_only: + return + err_msg = (f"Failed for {kwargs} with x = {x}. \n" + f"old_res = {old_res}, new_res = {new_res}. 
\n" + f"old_scale = {old_scale}, new_scale = {new_scale}") + if not np.allclose(old_res, new_res): + assert False, err_msg + if not np.allclose(old_scale, new_scale) and K.max(x) > 0: + assert False, err_msg + @pytest.mark.parametrize('alpha, threshold, test_values, expected_values', [ (1.0, 0.33, diff --git a/tests/qnoise_test.py b/tests/qnoise_test.py index b49531d7..1e442b57 100644 --- a/tests/qnoise_test.py +++ b/tests/qnoise_test.py @@ -24,10 +24,14 @@ import pytest from tensorflow.keras import backend as K from qkeras.quantizers import quantized_bits +from qkeras.quantizers import quantized_linear from qkeras.quantizers import quantized_relu -def test_qnoise_quantized_bits(): +@pytest.mark.parametrize('quantizer', [quantized_bits, quantized_linear]) +def test_qnoise_linear_quantizer(quantizer): + """Tests for quantized_bits and quantized_linear.""" + # 1 sign bit, 1 integer bit, and 2 fractional bits. bits = 4 integer = 1 @@ -36,14 +40,15 @@ def test_qnoise_quantized_bits(): alpha = 1 use_stochastic_rounding = False - qb = quantized_bits( + q = quantizer( bits=bits, integer=integer, symmetric=symmetric, keep_negative=keep_negative, alpha=alpha, use_stochastic_rounding=use_stochastic_rounding, - use_variables=True) + use_variables=True, + ) inputs = np.array([0.0, 0.5, -0.5, 0.6, -0.6, 2.0, -2.0], dtype=np.float32) x = np.array([0.0, 0.5, -0.5, 0.6, -0.6, 2.0, -2.0], dtype=np.float32) @@ -51,18 +56,18 @@ def test_qnoise_quantized_bits(): x_xq = 0.5 * (x + xq) # no quantization - qb.update_qnoise_factor(qnoise_factor=0.0) - x_q_0 = qb(inputs) + q.update_qnoise_factor(0.0) + x_q_0 = q(inputs) assert_equal(x_q_0, x) # full quantization - qb.update_qnoise_factor(qnoise_factor=1.0) - x_q_1 = qb(inputs) + q.update_qnoise_factor(1.0) + x_q_1 = q(inputs) assert_equal(x_q_1, xq) - # mixing half and half of x and xq - qb.update_qnoise_factor(qnoise_factor=0.5) - x_q_05 = qb(inputs) + # mixing half and half of x and gxq + q.update_qnoise_factor(0.5) + x_q_05 = q(inputs) assert_equal(x_q_05, x_xq) diff --git a/tests/qtools_model_test.py b/tests/qtools_model_test.py index cddd0a80..f27f8925 100644 --- a/tests/qtools_model_test.py +++ b/tests/qtools_model_test.py @@ -16,6 +16,7 @@ """Tests for various model architectures.""" import json +from collections import OrderedDict import numpy as np import pytest @@ -68,7 +69,7 @@ def qdense_model_fork(): return model -def qconv_model(): +def qconv_model(quantizer): x = x_in = keras.layers.Input((23, 23, 1), name="input") x = QActivation("quantized_relu(4)", name="QA_0")(x) x = QConv2D( @@ -78,17 +79,17 @@ def qconv_model(): name="qconv2d_1")(x) x = QConv2D( 8, 2, 2, - kernel_quantizer=quantizers.quantized_bits(4, 0, 1), - bias_quantizer=quantizers.quantized_bits(4, 0, 1), + kernel_quantizer=quantizer(4, 0, 1), + bias_quantizer=quantizer(4, 0, 1), activation=quantizers.quantized_relu(6, 2), name="qconv2D_2")(x) x = QConv2D( 2, 2, 2, - kernel_quantizer=quantizers.quantized_bits(4, 0, 1), - bias_quantizer=quantizers.quantized_bits(4, 0, 1), + kernel_quantizer=quantizer(4, 0, 1), + bias_quantizer=quantizer(4, 0, 1), activation=quantizers.quantized_relu(6, 2), name="qconv2d_3")(x) - x = QActivation("quantized_bits(6, 0, 1)", name="QA_4")(x) + x = QActivation(quantizer(6, 0, 1), name="QA_4")(x) model = keras.Model( inputs=[x_in], outputs=[x]) @@ -946,6 +947,62 @@ def test_qdepthwiseconv2d(): assert dtype_dict["pw_conv"]["accumulator"]["bits"] == 28 assert dtype_dict["pw_conv"]["accumulator"]["int_bits"] == 11 +def 
test_quantized_linear_backwards_compatibility(): + + def get_output_dict(model, quantizer): + """Get output dict from qtools""" + + input_quantizer_list = [quantizer()] + reference_internal = "int8" + reference_accumulator = "int32" + + # generate QTools object which contains model data type map in json format + q = run_qtools.QTools( + model, + # energy calculation using a given process + process="horowitz", + # quantizers for model inputs + source_quantizers=input_quantizer_list, + # training or inference with a pre-trained model + is_inference=False, + # path to pre-trained model weights + weights_path=None, + # keras_quantizer to quantize weight/bias in non-quantized keras layers + keras_quantizer=reference_internal, + # keras_accumulator to quantize MAC in un-quantized keras layers + keras_accumulator=reference_accumulator, + # calculating baseline energy or not + for_reference=False) + + return q._output_dict + + qbits_model = qconv_model(quantizers.quantized_bits) + qlinear_model = qconv_model(quantizers.quantized_linear) + + qbits_output_dict = get_output_dict( + qbits_model, quantizers.quantized_bits) + qlinear_output_dict = get_output_dict( + qlinear_model, quantizers.quantized_linear) + + def assert_output_dict_equal(qbits_output, qlinear_output): + # Check if the output dict of qbits and qlinear are the same + + if isinstance(qbits_output, OrderedDict): + assert isinstance(qlinear_output, OrderedDict) + for key in qbits_output: + assert key in qlinear_output + assert_output_dict_equal(qbits_output[key], qlinear_output[key]) + elif isinstance(qbits_output, list): + assert isinstance(qlinear_output, list) + for i in range(len(qbits_output)): + assert_output_dict_equal(qbits_output[i], qlinear_output[i]) + else: + if qbits_output == 'quantized_bits': + assert qlinear_output in ('quantized_linear', 'quantized_bits') + else: + assert qbits_output == qlinear_output + + assert_output_dict_equal(qbits_output_dict, qlinear_output_dict) def test_divide_and_conquer_sequential_conv2d(): # These following values are verified manually to be globally optimal. 
diff --git a/tests/quantizer_impl_test.py b/tests/quantizer_impl_test.py index a4aa82e6..95560ac1 100644 --- a/tests/quantizer_impl_test.py +++ b/tests/quantizer_impl_test.py @@ -241,9 +241,9 @@ def test_GetScale_PerChannelScale(): # Rank1 tensors x_r1 = tf.ones([4]) q_r1 = tf.ones([4]) - scale_r1_pcs_true = quantizers._get_scale( + scale_r1_pcs_true = quantizers._get_least_squares_scale( "auto", x_r1, q_r1, scale_axis=None, per_channel_scale=True) - scale_r1_pcs_false = quantizers._get_scale( + scale_r1_pcs_false = quantizers._get_least_squares_scale( "auto", x_r1, q_r1, scale_axis=None, per_channel_scale=False) assert_equal(tf.shape(scale_r1_pcs_true).numpy(), [4]) assert_equal(tf.shape(scale_r1_pcs_false).numpy(), [1]) @@ -251,9 +251,9 @@ def test_GetScale_PerChannelScale(): # Rank2 tensors x_r2 = tf.ones([2, 4]) q_r2 = tf.ones([2, 4]) - scale_r2_pcs_true = quantizers._get_scale( + scale_r2_pcs_true = quantizers._get_least_squares_scale( "auto", x_r2, q_r2, scale_axis=None, per_channel_scale=True) - scale_r2_pcs_false = quantizers._get_scale( + scale_r2_pcs_false = quantizers._get_least_squares_scale( "auto", x_r2, q_r2, scale_axis=None, per_channel_scale=False) assert_equal(tf.shape(scale_r2_pcs_true).numpy(), [1, 4]) assert_equal(tf.shape(scale_r2_pcs_false).numpy(), [1, 1]) @@ -261,9 +261,9 @@ def test_GetScale_PerChannelScale(): # Rank3 tensors x_r3 = tf.ones([3, 3, 4]) q_r3 = tf.ones([3, 3, 4]) - scale_r3_pcs_true = quantizers._get_scale( + scale_r3_pcs_true = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, scale_axis=None, per_channel_scale=True) - scale_r3_pcs_false = quantizers._get_scale( + scale_r3_pcs_false = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, scale_axis=None, per_channel_scale=False) assert_equal(tf.shape(scale_r3_pcs_true).numpy(), [1, 1, 4]) assert_equal(tf.shape(scale_r3_pcs_false).numpy(), [1, 1, 1]) @@ -271,9 +271,9 @@ def test_GetScale_PerChannelScale(): # Rank4 tensors x_r4 = tf.ones([1, 1, 3, 4]) q_r4 = tf.ones([1, 1, 3, 4]) - scale_r4_pcs_true = quantizers._get_scale( + scale_r4_pcs_true = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, scale_axis=None, per_channel_scale=True) - scale_r4_pcs_false = quantizers._get_scale( + scale_r4_pcs_false = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, scale_axis=None, per_channel_scale=False) assert_equal(tf.shape(scale_r4_pcs_true).numpy(), [1, 1, 1, 4]) assert_equal(tf.shape(scale_r4_pcs_false).numpy(), [1, 1, 1, 1]) @@ -288,11 +288,11 @@ def test_GetScale_ElementsPerScale_Scalar_ScaleAxis_EPS(): # values and the input x and q tensors have rank 2 x_r2 = tf.random.uniform([4, 8]) q_r2 = tf.random.uniform([4, 8]) - scale_r2_eps_none_ua_none = quantizers._get_scale( + scale_r2_eps_none_ua_none = quantizers._get_least_squares_scale( "auto", x_r2, q_r2, elements_per_scale=None, scale_axis=None) - scale_r2_eps_2_ua_0 = quantizers._get_scale( + scale_r2_eps_2_ua_0 = quantizers._get_least_squares_scale( "auto", x_r2, q_r2, elements_per_scale=2, scale_axis=0) - scale_r2_eps_2_ua_1 = quantizers._get_scale( + scale_r2_eps_2_ua_1 = quantizers._get_least_squares_scale( "auto", x_r2, q_r2, elements_per_scale=2, scale_axis=1) assert_equal(tf.shape(scale_r2_eps_none_ua_none).numpy(), [1, 8]) @@ -308,13 +308,13 @@ def test_GetScale_ElementsPerScale_Scalar_ScaleAxis_EPS(): # values and the input x and q tensors have rank 3 x_r3 = tf.random.uniform([2, 4, 8]) q_r3 = tf.random.uniform([2, 4, 8]) - scale_r3_eps_none_ua_none = quantizers._get_scale( + scale_r3_eps_none_ua_none = 
quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=None, scale_axis=None) - scale_r3_eps_2_ua_0 = quantizers._get_scale( + scale_r3_eps_2_ua_0 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=2, scale_axis=0) - scale_r3_eps_2_ua_1 = quantizers._get_scale( + scale_r3_eps_2_ua_1 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=2, scale_axis=1) - scale_r3_eps_2_ua_2 = quantizers._get_scale( + scale_r3_eps_2_ua_2 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=2, scale_axis=2) assert_equal(tf.shape(scale_r3_eps_none_ua_none).numpy(), [1, 1, 8]) @@ -333,15 +333,15 @@ def test_GetScale_ElementsPerScale_Scalar_ScaleAxis_EPS(): # values and the input x and q tensors have rank 4 x_r4 = tf.random.uniform([2, 4, 8, 16]) q_r4 = tf.random.uniform([2, 4, 8, 16]) - scale_r4_eps_none_ua_none = quantizers._get_scale( + scale_r4_eps_none_ua_none = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=None, scale_axis=None) - scale_r4_eps_2_ua_0 = quantizers._get_scale( + scale_r4_eps_2_ua_0 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=2, scale_axis=0) - scale_r4_eps_2_ua_1 = quantizers._get_scale( + scale_r4_eps_2_ua_1 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=2, scale_axis=1) - scale_r4_eps_2_ua_2 = quantizers._get_scale( + scale_r4_eps_2_ua_2 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=2, scale_axis=2) - scale_r4_eps_2_ua_3 = quantizers._get_scale( + scale_r4_eps_2_ua_3 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=2, scale_axis=3) assert_equal(tf.shape(scale_r4_eps_none_ua_none).numpy(), [1, 1, 1, 16]) @@ -366,13 +366,13 @@ def test_GetScale_ElementsPerScale_List_ScaleAxis_EPS(): x_r3 = tf.random.uniform([2, 4, 8]) q_r3 = tf.random.uniform([2, 4, 8]) - scale_r3_eps_none_ua_0 = quantizers._get_scale( + scale_r3_eps_none_ua_0 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=None, scale_axis=[0]) - scale_r3_eps_2_ua_0 = quantizers._get_scale( + scale_r3_eps_2_ua_0 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[2], scale_axis=[0]) - scale_r3_eps_2_ua_1 = quantizers._get_scale( + scale_r3_eps_2_ua_1 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[2], scale_axis=[1]) - scale_r3_eps_2_ua_2 = quantizers._get_scale( + scale_r3_eps_2_ua_2 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[2], scale_axis=[2]) assert_equal(tf.shape(scale_r3_eps_none_ua_0).numpy(), [2, 1, 1]) @@ -392,13 +392,13 @@ def test_GetScale_ElementsPerScale_List_ScaleAxis_EPS(): x_r3 = tf.random.uniform([2, 4, 8]) q_r3 = tf.random.uniform([2, 4, 8]) - scale_r3_eps_none_ua_01 = quantizers._get_scale( + scale_r3_eps_none_ua_01 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=None, scale_axis=[0, 1]) - scale_r3_eps_22_ua_01 = quantizers._get_scale( + scale_r3_eps_22_ua_01 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[2, 2], scale_axis=[0, 1]) - scale_r3_eps_11_ua_12 = quantizers._get_scale( + scale_r3_eps_11_ua_12 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[2, 2], scale_axis=[1, 2]) - scale_r3_eps_11_ua_02 = quantizers._get_scale( + scale_r3_eps_11_ua_02 = quantizers._get_least_squares_scale( "auto", x_r3, q_r3, elements_per_scale=[1, 1], scale_axis=[0, 2]) 
assert_equal(tf.shape(scale_r3_eps_none_ua_01).numpy(), [2, 4, 1]) @@ -418,13 +418,13 @@ def test_GetScale_ElementsPerScale_List_ScaleAxis_EPS(): x_r4 = tf.random.uniform([2, 4, 8, 16]) q_r4 = tf.random.uniform([2, 4, 8, 16]) - scale_r4_eps_none_ua_012 = quantizers._get_scale( + scale_r4_eps_none_ua_012 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=None, scale_axis=[0, 1, 2]) - scale_r4_eps_221_ua_012 = quantizers._get_scale( + scale_r4_eps_221_ua_012 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=[2, 2, 1], scale_axis=[0, 1, 2]) - scale_r4_eps_221_ua_123 = quantizers._get_scale( + scale_r4_eps_221_ua_123 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=[2, 2, 1], scale_axis=[1, 2, 3]) - scale_r4_eps_221_ua_013 = quantizers._get_scale( + scale_r4_eps_221_ua_013 = quantizers._get_least_squares_scale( "auto", x_r4, q_r4, elements_per_scale=[2, 2, 1], scale_axis=[0, 1, 3]) assert_equal(tf.shape(scale_r4_eps_none_ua_012).numpy(), [2, 4, 8, 1]) @@ -454,15 +454,15 @@ def _get_min_max_po2_exponent(x): q = 2**tf.random.uniform(shape=[2, 4, 8], minval=-50, maxval=0) # set various min and max po2 exponents for the scale - scale_min_neg3_max_1 = quantizers._get_scale( + scale_min_neg3_max_1 = quantizers._get_least_squares_scale( "auto_po2", x, q, elements_per_scale=4, scale_axis=2, min_po2_exponent=-3, max_po2_exponent=1) - scale_min_neg8_max_0 = quantizers._get_scale( + scale_min_neg8_max_0 = quantizers._get_least_squares_scale( "auto_po2", x, q, elements_per_scale=4, scale_axis=2, min_po2_exponent=-8, max_po2_exponent=0) - scale_min_neg10_max_1 = quantizers._get_scale( + scale_min_neg10_max_1 = quantizers._get_least_squares_scale( "auto_po2", x, q, elements_per_scale=4, scale_axis=2, min_po2_exponent=-10, max_po2_exponent=1) diff --git a/tests/range_test.py b/tests/range_test.py index 6339e15b..7769fa31 100644 --- a/tests/range_test.py +++ b/tests/range_test.py @@ -22,9 +22,11 @@ import pytest from tensorflow.keras import backend as K +import tensorflow as tf from qkeras import quantized_relu from qkeras import quantized_bits +from qkeras import quantized_linear @pytest.mark.parametrize( @@ -73,6 +75,53 @@ def test_quantized_bits_range(bits, integer, expected_values): result = q.range() assert_allclose(result, expected_values, rtol=1e-05) +@pytest.mark.parametrize('alpha', [None, 2.0]) +@pytest.mark.parametrize('symmetric,keep_negative', + [(True, True), (False, True), (False, False)]) +@pytest.mark.parametrize('bits', [1, 8]) +def test_quantized_linear_range(bits, symmetric, keep_negative, alpha): + """Test quantized_linear range function.""" + q = quantized_linear(bits, 0, symmetric=symmetric, keep_negative=keep_negative, + alpha=alpha) + # compute output on array of inputs, and compare to q.range() + x = np.linspace(-10.0, 10.0, 10 * 2**(bits + 1) + 1) + y = q(x) + q_range = q.range() + # assert that y and q_range have the same set of values + _assert_same_unique_values(q_range, y) + # assert that values are ordered on the binary range, ascending + _assert_binary_range_ordering(q_range) + + +def _assert_same_unique_values(x, y): + """Check if two TensorFlow tensors have the same unique set of values.""" + # Get the unique values of each tensor + unique_x = tf.unique(x)[0].numpy() + unique_y = tf.unique(y)[0].numpy() + + # sort the unique values + unique_x.sort() + unique_y.sort() + + assert unique_x.shape == unique_y.shape + assert np.allclose(unique_x, unique_y) + + +def _assert_binary_range_ordering(x): + 
"""Assert that x is ordered by binary representation ascending""" + + x = np.array(x) + # get positive values in x + x_pos = x[x >= 0] + # get negative values in x + x_neg = x[x < 0] + # assert that positive values are ordered ascending + assert np.all(np.diff(x_pos) >= 0) + # assert that negative values are ordered ascending + assert np.all(np.diff(x_neg) >= 0) + # assert that all positive values come before negative values + assert np.all(x == np.concatenate([x_pos, x_neg])) + if __name__ == "__main__": pytest.main([__file__])