Skip to content

Commit 73e0781

Browse files
committed
Accept more variable inference input
1 parent 85dc3d5 commit 73e0781

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

fast_plate_ocr/inference/onnx_inference.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from contextlib import nullcontext
77

8+
import numpy as np
89
import numpy.typing as npt
910
import onnxruntime as ort
1011

@@ -16,6 +17,28 @@
1617
logging.basicConfig(level=logging.INFO)
1718

1819

20+
def _load_image_from_source(source: str | list[str] | npt.NDArray) -> npt.NDArray:
21+
"""
22+
Loads an image from a given source.
23+
24+
:param source: Path to the input image file or numpy array representing the image.
25+
:return: Numpy array representing the input image.
26+
"""
27+
if isinstance(source, str):
28+
return read_plate_image(source)
29+
30+
if isinstance(source, list) and isinstance(source[0], str):
31+
return np.array([read_plate_image(i) for i in source])
32+
33+
if isinstance(source, np.ndarray):
34+
if source.ndim > 3:
35+
raise ValueError("Expected source to be of shape (H, W, 1) or (H, W) or (1, H, W, 1)")
36+
source = source.squeeze()
37+
return source
38+
39+
raise ValueError("Unsupported input type. Only file path or numpy array is supported.")
40+
41+
1942
class ONNXInference:
2043
"""
2144
ONNX inference class for performing license plates OCR.
@@ -45,17 +68,17 @@ def __init__(self, ocr_model: str, use_gpu: bool = False, log_time: bool = False
4568

4669
def run(
4770
self,
48-
image_path: str,
71+
source: str | list[str] | npt.NDArray,
4972
return_confidence: bool = False,
5073
) -> tuple[list[str], npt.NDArray] | list[str]:
5174
"""
5275
Runs inference on an image.
5376
54-
:param image_path: Path to the input image file.
77+
:param source: Path to the input image file or numpy array representing the image.
5578
:param return_confidence: Whether to return confidence scores along with plate predictions.
5679
:return: Decoded license plate characters as a list.
5780
"""
58-
x = read_plate_image(image_path)
81+
x = _load_image_from_source(source)
5982
with log_time_taken("Pre-process") if self.log_time else nullcontext():
6083
x = preprocess_image(x, self.config["img_height"], self.config["img_width"])
6184
with log_time_taken("Model run") if self.log_time else nullcontext():

0 commit comments

Comments
 (0)