From a5b5365ef47c1dab818f6028ee6c0bf9be7a389e Mon Sep 17 00:00:00 2001 From: Elad Cohen <54019246+eladc-git@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:42:46 +0200 Subject: [PATCH] Fix MP error function with axis (#991) removes the batch list for MP calculation and use the internal batch flag in the error function. Special care for KL divergence error function. --- .../mixed_precision/sensitivity_evaluation.py | 35 ++++++++++--------- .../core/common/similarity_analyzer.py | 8 ++--- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py index 63a47ea7b..a1d1dc520 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py @@ -17,11 +17,11 @@ import numpy as np from typing import Callable, Any, List, Tuple -from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA +from model_compression_toolkit.constants import AXIS from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode - +from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \ @@ -90,10 +90,10 @@ def __init__(self, fw_impl.count_node_for_mixed_precision_interest_points, quant_config.num_interest_points_factor) - self.ips_distance_fns, self.ips_batch_axis = self._init_metric_points_lists(self.interest_points) + self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points) self.output_points = get_output_nodes_for_metric(graph) - self.out_ps_distance_fns, self.out_ps_batch_axis = self._init_metric_points_lists(self.output_points) + self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points) # Setting lists with relative position of the interest points # and output points in the list of all mp model activation tensors @@ -133,26 +133,27 @@ def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callab """ Initiates required lists for future use when computing the sensitivity metric. Each point on which the metric is computed uses a dedicated distance function based on its type. - In addition, all distance functions preform batch computation, so the batch axis is needed for each node. + In addition, all distance functions preform batch computation. Axis is needed only for KL Divergence computation. Args: points: The set of nodes in the graph for which we need to initiate the lists. - Returns: A lists with distance functions and a list batch axis for each node. + Returns: A lists with distance functions and an axis list for each node. """ distance_fns_list = [] - batch_axis_list = [] + axis_list = [] for n in points: - axis = n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(AXIS) - distance_fns_list.append(self.fw_impl.get_node_distance_fn( + distance_fn = self.fw_impl.get_node_distance_fn( layer_class=n.layer_class, framework_attrs=n.framework_attr, compute_distance_fn=self.quant_config.compute_distance_fn, - axis=axis)) - batch_axis_list.append(axis) - return distance_fns_list, batch_axis_list + axis=axis) + distance_fns_list.append(distance_fn) + # Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation + axis_list.append(axis if distance_fn==compute_kl_divergence else None) + return distance_fns_list, axis_list def compute_metric(self, mp_model_configuration: List[int], @@ -329,7 +330,7 @@ def _compute_points_distance(self, baseline_tensors: List[Any], mp_tensors: List[Any], points_distance_fns: List[Callable], - points_batch_axis: List[int]): + points_axis: List[int]): """ Compute the distance on the given set of points outputs between the MP model and the baseline model for each image in the batch that was inferred. @@ -339,7 +340,7 @@ def _compute_points_distance(self, mp_tensors: MP model's output tensors pf the given points. points_distance_fns: A list with distance function to compute the distance between each given point's output tensors. - points_batch_axis: A list with the matching batch axis of each given point's output tensors. + points_axis: A list with the matching axis of each given point's output tensors. Returns: A distance vector that maps each node's index in the given nodes list to the distance between this node's output @@ -347,7 +348,7 @@ def _compute_points_distance(self, """ distance_v = [fn(x, y, batch=True, axis=axis) for fn, x, y, axis - in zip(points_distance_fns, baseline_tensors, mp_tensors, points_batch_axis)] + in zip(points_distance_fns, baseline_tensors, mp_tensors, points_axis)] return np.asarray(distance_v) @@ -373,11 +374,11 @@ def _compute_distance(self) -> Tuple[np.ndarray, np.ndarray]: ips_distance = self._compute_points_distance([baseline_tensors[i] for i in self.ips_act_indices], [mp_tensors[i] for i in self.ips_act_indices], self.ips_distance_fns, - self.ips_batch_axis) + self.ips_axis) outputs_distance = self._compute_points_distance([baseline_tensors[i] for i in self.out_ps_act_indices], [mp_tensors[i] for i in self.out_ps_act_indices], self.out_ps_distance_fns, - self.out_ps_batch_axis) + self.out_ps_axis) # Extending the dimensions for the concatenation at the end in case we need to ips_distance = ips_distance if len(ips_distance.shape) > 1 else ips_distance[:, None] diff --git a/model_compression_toolkit/core/common/similarity_analyzer.py b/model_compression_toolkit/core/common/similarity_analyzer.py index 2caccbea2..3ed30f4aa 100644 --- a/model_compression_toolkit/core/common/similarity_analyzer.py +++ b/model_compression_toolkit/core/common/similarity_analyzer.py @@ -97,7 +97,7 @@ def compute_mse(float_tensor: np.ndarray, norm: whether to normalize the error function result. norm_eps: epsilon value for error normalization stability. batch: Whether to run batch similarity analysis or not. - axis: Axis along which the operator has been computed (not used in this function). + axis: Axis along which the operator has been computed. Returns: The MSE distance between the two tensors. @@ -129,7 +129,7 @@ def compute_mae(float_tensor: np.ndarray, norm: whether to normalize the error function result. norm_eps: epsilon value for error normalization stability. batch: Whether to run batch similarity analysis or not. - axis: Axis along which the operator has been computed (not used in this function). + axis: Axis along which the operator has been computed. Returns: The mean average distance between the two tensors. @@ -158,7 +158,7 @@ def compute_cs(float_tensor: np.ndarray, fxp_tensor: np.ndarray, eps: float = 1e fxp_tensor: Second tensor to compare. eps: Small value to avoid zero division. batch: Whether to run batch similarity analysis or not. - axis: Axis along which the operator has been computed (not used in this function). + axis: Axis along which the operator has been computed. Returns: The cosine similarity between two tensors. @@ -200,7 +200,7 @@ def compute_lp_norm(float_tensor: np.ndarray, norm: whether to normalize the error function result. norm_eps: epsilon value for error normalization stability. batch: Whether to run batch similarity analysis or not. - axis: Axis along which the operator has been computed (not used in this function). + axis: Axis along which the operator has been computed. Returns: The Lp-norm distance between the two tensors.