-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
69 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
ONNX inference module. | ||
""" | ||
|
||
import logging | ||
from contextlib import nullcontext | ||
|
||
import numpy.typing as npt | ||
import onnxruntime as ort | ||
|
||
import fast_plate_ocr.inference.config | ||
from fast_plate_ocr.common.utils import log_time_taken | ||
from fast_plate_ocr.inference import hub | ||
from fast_plate_ocr.inference.process import postprocess_output, preprocess_image, read_plate_image | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
class ONNXInference: | ||
""" | ||
ONNX inference class for performing license plates OCR. | ||
""" | ||
|
||
def __init__(self, ocr_model: str, use_gpu: bool = False, log_time: bool = False): | ||
""" | ||
Initialize ONNXInference. | ||
:param ocr_model: Name of the OCR model to use. | ||
:param use_gpu: Flag indicating whether to use GPU backend. | ||
""" | ||
self.logger = logging.getLogger(__name__) | ||
self.log_time = log_time | ||
|
||
if use_gpu: | ||
self.providers = ["CUDAExecutionProvider"] | ||
self.device = "GPU" | ||
else: | ||
self.providers = ["CPUExecutionProvider"] | ||
self.device = "CPU" | ||
|
||
model_path, config_path = hub.download_model(model_name=ocr_model) | ||
self.config = fast_plate_ocr.inference.config.load_config_from_yaml(config_path) | ||
self.model = ort.InferenceSession(model_path, providers=self.providers) | ||
self.logger.info("Using ONNX Runtime with %s device.", self.device) | ||
|
||
def run( | ||
self, | ||
image_path: str, | ||
return_confidence: bool = False, | ||
) -> tuple[list[str], npt.NDArray] | list[str]: | ||
""" | ||
Runs inference on an image. | ||
:param image_path: Path to the input image file. | ||
:param return_confidence: Whether to return confidence scores along with plate predictions. | ||
:return: Decoded license plate characters as a list. | ||
""" | ||
x = read_plate_image(image_path) | ||
with log_time_taken("Pre-process") if self.log_time else nullcontext(): | ||
x = preprocess_image(x, self.config["img_height"], self.config["img_width"]) | ||
with log_time_taken("Model run") if self.log_time else nullcontext(): | ||
y: list[npt.NDArray] = self.model.run(None, {"input": x}) | ||
with log_time_taken("Post-process") if self.log_time else nullcontext(): | ||
return postprocess_output( | ||
y[0], | ||
self.config["max_plate_slots"], | ||
self.config["alphabet"], | ||
return_confidence=return_confidence, | ||
) |