Skip to content

Commit

Permalink
Fix metric plotter
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle committed Nov 23, 2024
1 parent 3985b05 commit 45ae679
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions netspresso/utils/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ def _add_value_annotations(ax, bar, value):
)

@staticmethod
def _add_difference_annotations(ax, original_value, compressed_value, difference):
def _add_difference_annotations(ax, original_value, compressed_value, difference, scatter_names = None):
if scatter_names is None:
scatter_names = ["Original Model", "Compressed Model"]
ax.scatter(
["Original Model", "Compressed Model"],
scatter_names,
[original_value, compressed_value],
color="red",
marker="o",
zorder=3,
)

ax.plot(
["Original Model", "Compressed Model"],
scatter_names,
[original_value, compressed_value],
color="red",
linestyle="--",
Expand Down Expand Up @@ -193,6 +195,7 @@ def compare_metrics(metric_data1, metric_data2, metric_labels: List[str]):
metric_data1[_metric],
metric_data2[_metric],
metric_data2[_metric] - metric_data2[_metric],
["Original Model", "Quantized Model"],
)

Plotter._set_common_plot_settings(axs[idx], metric)
Expand Down

0 comments on commit 45ae679

Please sign in to comment.