Skip to content

Commit

Permalink
Prettify benchmark result
Browse files Browse the repository at this point in the history
  • Loading branch information
ankandrew committed Oct 14, 2024
1 parent 82666fb commit 83e8b3a
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions fast_plate_ocr/inference/onnx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import numpy.typing as npt
import onnxruntime as ort
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from fast_plate_ocr.common.utils import measure_time
from fast_plate_ocr.inference import hub
Expand Down Expand Up @@ -137,7 +139,7 @@ def __init__(
self.model = ort.InferenceSession(
model_path, providers=self.providers, sess_options=sess_options
)
self.logger.info("Using ONNX Runtime with %s.", self.providers[0])
self.logger.info("Using ONNX Runtime with %s.", self.providers)

def benchmark(self, n_iter: int = 10_000, include_processing: bool = False) -> None:
"""
Expand Down Expand Up @@ -165,12 +167,20 @@ def benchmark(self, n_iter: int = 10_000, include_processing: bool = False) -> N
avg_time = (cum_time / n_iter) if n_iter > 0 else 0.0
avg_pps = (1_000 / avg_time) if n_iter > 0 else 0.0

table = Table(title=f"Benchmark '{self.model_name}' model")
table.add_column("Executor", justify="center", style="cyan", no_wrap=True)
table.add_column("Average ms", style="magenta", justify="center")
table.add_column("Plates/second", style="magenta", justify="center")
table.add_row(self.providers[0], f"{avg_time:.4f}", f"{avg_pps:.4f}")
console = Console()
model_info = Panel(
Text(f"Model: {self.model_name}\nProviders: {self.providers}", style="bold green"),
title="Model Information",
border_style="bright_blue",
expand=False,
)
console.print(model_info)
table = Table(title=f"Benchmark for '{self.model_name}' Model", border_style="bright_blue")
table.add_column("Metric", justify="center", style="cyan", no_wrap=True)
table.add_column("Value", justify="center", style="magenta")
table.add_row("Number of Iterations", str(n_iter))
table.add_row("Average Time (ms)", f"{avg_time:.4f}")
table.add_row("Plates Per Second (PPS)", f"{avg_pps:.4f}")
console.print(table)

def run(
Expand Down

0 comments on commit 83e8b3a

Please sign in to comment.