Fix MP error function with axis (#991)
Removes the batch-axis list from the MP metric computation and uses the error functions' internal batch flag instead. The KL divergence error function gets special care, as it is the only one that still needs a per-node axis.
eladc-git authored Mar 12, 2024
1 parent 594ce9d commit a5b5365
Showing 2 changed files with 22 additions and 21 deletions.
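
A rough sketch of the shape of the change (illustrative names only, not the actual MCT code): each interest point is paired with its distance function, the axis is kept only when that function is KL divergence, and batching is handled inside the error function through its `batch` flag rather than through a separate batch-axis list.

```python
import numpy as np

def build_distance_lists(points, get_distance_fn, kl_divergence_fn):
    """Illustrative: pair each point with a distance fn and an optional axis."""
    distance_fns, axes = [], []
    for node in points:
        axis = getattr(node, "axis", None)  # e.g. a softmax axis from the node's attributes
        fn = get_distance_fn(node)          # mse / cosine / lp-norm / kl, chosen per node type
        distance_fns.append(fn)
        # Only KL divergence needs the axis; all other metrics run per tensor.
        axes.append(axis if fn is kl_divergence_fn else None)
    return distance_fns, axes

def points_distance(baseline_tensors, mp_tensors, distance_fns, axes):
    """Illustrative: one batched distance vector per point, batching done inside each fn."""
    return np.asarray([fn(x, y, batch=True, axis=ax)
                       for fn, x, y, ax in zip(distance_fns, baseline_tensors, mp_tensors, axes)])
```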
@@ -17,11 +17,11 @@
import numpy as np
from typing import Callable, Any, List, Tuple

from model_compression_toolkit.constants import AXIS, HESSIAN_OUTPUT_ALPHA
from model_compression_toolkit.constants import AXIS
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode

from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, \
@@ -90,10 +90,10 @@ def __init__(self,
fw_impl.count_node_for_mixed_precision_interest_points,
quant_config.num_interest_points_factor)

self.ips_distance_fns, self.ips_batch_axis = self._init_metric_points_lists(self.interest_points)
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points)

self.output_points = get_output_nodes_for_metric(graph)
self.out_ps_distance_fns, self.out_ps_batch_axis = self._init_metric_points_lists(self.output_points)
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points)

# Setting lists with relative position of the interest points
# and output points in the list of all mp model activation tensors
@@ -133,26 +133,27 @@ def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callab
"""
Initiates required lists for future use when computing the sensitivity metric.
Each point on which the metric is computed uses a dedicated distance function based on its type.
In addition, all distance functions perform batch computation, so the batch axis is needed for each node.
In addition, all distance functions perform batch computation. The axis is needed only for the KL divergence computation.
Args:
points: The set of nodes in the graph for which we need to initiate the lists.
Returns: A list with distance functions and a batch-axis list for each node.
Returns: A list with distance functions and an axis list for each node.
"""
distance_fns_list = []
batch_axis_list = []
axis_list = []
for n in points:

axis = n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(AXIS)
distance_fns_list.append(self.fw_impl.get_node_distance_fn(
distance_fn = self.fw_impl.get_node_distance_fn(
layer_class=n.layer_class,
framework_attrs=n.framework_attr,
compute_distance_fn=self.quant_config.compute_distance_fn,
axis=axis))
batch_axis_list.append(axis)
return distance_fns_list, batch_axis_list
axis=axis)
distance_fns_list.append(distance_fn)
# Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
axis_list.append(axis if distance_fn==compute_kl_divergence else None)
return distance_fns_list, axis_list

def compute_metric(self,
mp_model_configuration: List[int],
@@ -329,7 +330,7 @@ def _compute_points_distance(self,
baseline_tensors: List[Any],
mp_tensors: List[Any],
points_distance_fns: List[Callable],
points_batch_axis: List[int]):
points_axis: List[int]):
"""
Compute the distance on the given set of points outputs between the MP model and the baseline model
for each image in the batch that was inferred.
@@ -339,15 +340,15 @@
mp_tensors: MP model's output tensors of the given points.
points_distance_fns: A list with distance function to compute the distance between each given
point's output tensors.
points_batch_axis: A list with the matching batch axis of each given point's output tensors.
points_axis: A list with the matching axis of each given point's output tensors.
Returns:
A distance vector that maps each node's index in the given nodes list to the distance between this node's output
and the baseline model's output for all images that were inferred.
"""

distance_v = [fn(x, y, batch=True, axis=axis) for fn, x, y, axis
in zip(points_distance_fns, baseline_tensors, mp_tensors, points_batch_axis)]
in zip(points_distance_fns, baseline_tensors, mp_tensors, points_axis)]

return np.asarray(distance_v)

Expand All @@ -373,11 +374,11 @@ def _compute_distance(self) -> Tuple[np.ndarray, np.ndarray]:
ips_distance = self._compute_points_distance([baseline_tensors[i] for i in self.ips_act_indices],
[mp_tensors[i] for i in self.ips_act_indices],
self.ips_distance_fns,
self.ips_batch_axis)
self.ips_axis)
outputs_distance = self._compute_points_distance([baseline_tensors[i] for i in self.out_ps_act_indices],
[mp_tensors[i] for i in self.out_ps_act_indices],
self.out_ps_distance_fns,
self.out_ps_batch_axis)
self.out_ps_axis)

# Extending the dimensions for the concatenation at the end in case we need to
ips_distance = ips_distance if len(ips_distance.shape) > 1 else ips_distance[:, None]
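As context for the `compute_kl_divergence` special case above, a minimal sketch of why KL divergence is the one metric that still needs a per-node axis: KL compares probability distributions, so tensors must first be normalized along the axis that holds the distribution (e.g. a softmax axis), whereas the other metrics reduce per tensor or per sample. This is an illustrative re-implementation with assumed semantics, not MCT's actual function.

```python
import numpy as np

def kl_divergence_batch(p_logits: np.ndarray, q_logits: np.ndarray,
                        axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Illustrative batched KL divergence: returns one distance per sample.

    Inputs are normalized with softmax along `axis` so each slice is a probability
    distribution; KL is summed over that axis and the remaining dims are averaged.
    """
    def softmax(x, ax):
        x = x - x.max(axis=ax, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=ax, keepdims=True)

    p = softmax(p_logits, axis)
    q = softmax(q_logits, axis)
    kl = np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=axis)
    # Average whatever dimensions remain, except the batch axis (assumed to be 0).
    return kl.reshape(kl.shape[0], -1).mean(axis=-1)
```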
8 changes: 4 additions & 4 deletions model_compression_toolkit/core/common/similarity_analyzer.py
@@ -97,7 +97,7 @@ def compute_mse(float_tensor: np.ndarray,
norm: whether to normalize the error function result.
norm_eps: epsilon value for error normalization stability.
batch: Whether to run batch similarity analysis or not.
axis: Axis along which the operator has been computed (not used in this function).
axis: Axis along which the operator has been computed.
Returns:
The MSE distance between the two tensors.
@@ -129,7 +129,7 @@ def compute_mae(float_tensor: np.ndarray,
norm: whether to normalize the error function result.
norm_eps: epsilon value for error normalization stability.
batch: Whether to run batch similarity analysis or not.
axis: Axis along which the operator has been computed (not used in this function).
axis: Axis along which the operator has been computed.
Returns:
The mean average distance between the two tensors.
@@ -158,7 +158,7 @@ def compute_cs(float_tensor: np.ndarray, fxp_tensor: np.ndarray, eps: float = 1e
fxp_tensor: Second tensor to compare.
eps: Small value to avoid zero division.
batch: Whether to run batch similarity analysis or not.
axis: Axis along which the operator has been computed (not used in this function).
axis: Axis along which the operator has been computed.
Returns:
The cosine similarity between two tensors.
@@ -200,7 +200,7 @@ def compute_lp_norm(float_tensor: np.ndarray,
norm: whether to normalize the error function result.
norm_eps: epsilon value for error normalization stability.
batch: Whether to run batch similarity analysis or not.
axis: Axis along which the operator has been computed (not used in this function).
axis: Axis along which the operator has been computed.
Returns:
The Lp-norm distance between the two tensors.
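To make the `batch` and `axis` semantics in these docstrings concrete, here is a hedged, self-contained sketch in the spirit of `compute_mse` (names and defaults are illustrative, not the library's exact signature): with `batch=True` a distance value is produced per image along the leading batch dimension, and `axis` is simply ignored by metrics that do not need it.

```python
import numpy as np

def mse_distance(float_tensor: np.ndarray, fxp_tensor: np.ndarray,
                 batch: bool = False, axis=None):
    """Illustrative MSE distance; `axis` is accepted but unused for this metric."""
    if batch:
        # One distance per sample: flatten everything except the leading batch axis.
        diff = (float_tensor - fxp_tensor).reshape(float_tensor.shape[0], -1)
        return np.mean(diff ** 2, axis=-1)
    # Per-tensor computation: a single scalar for the whole tensor comparison.
    return np.mean((float_tensor - fxp_tensor) ** 2)
```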
