|
5 | 5 | import logging
|
6 | 6 | from contextlib import nullcontext
|
7 | 7 |
|
| 8 | +import numpy as np |
8 | 9 | import numpy.typing as npt
|
9 | 10 | import onnxruntime as ort
|
10 | 11 |
|
|
16 | 17 | logging.basicConfig(level=logging.INFO)
|
17 | 18 |
|
18 | 19 |
|
| 20 | +def _load_image_from_source(source: str | list[str] | npt.NDArray) -> npt.NDArray: |
| 21 | + """ |
| 22 | + Loads an image from a given source. |
| 23 | +
|
| 24 | + :param source: Path to the input image file or numpy array representing the image. |
| 25 | + :return: Numpy array representing the input image. |
| 26 | + """ |
| 27 | + if isinstance(source, str): |
| 28 | + return read_plate_image(source) |
| 29 | + |
| 30 | + if isinstance(source, list) and isinstance(source[0], str): |
| 31 | + return np.array([read_plate_image(i) for i in source]) |
| 32 | + |
| 33 | + if isinstance(source, np.ndarray): |
| 34 | + if source.ndim > 3: |
| 35 | + raise ValueError("Expected source to be of shape (H, W, 1) or (H, W) or (1, H, W, 1)") |
| 36 | + source = source.squeeze() |
| 37 | + return source |
| 38 | + |
| 39 | + raise ValueError("Unsupported input type. Only file path or numpy array is supported.") |
| 40 | + |
| 41 | + |
19 | 42 | class ONNXInference:
|
20 | 43 | """
|
21 | 44 | ONNX inference class for performing license plates OCR.
|
@@ -45,17 +68,17 @@ def __init__(self, ocr_model: str, use_gpu: bool = False, log_time: bool = False
|
45 | 68 |
|
46 | 69 | def run(
|
47 | 70 | self,
|
48 |
| - image_path: str, |
| 71 | + source: str | list[str] | npt.NDArray, |
49 | 72 | return_confidence: bool = False,
|
50 | 73 | ) -> tuple[list[str], npt.NDArray] | list[str]:
|
51 | 74 | """
|
52 | 75 | Runs inference on an image.
|
53 | 76 |
|
54 |
| - :param image_path: Path to the input image file. |
| 77 | + :param source: Path to the input image file or numpy array representing the image. |
55 | 78 | :param return_confidence: Whether to return confidence scores along with plate predictions.
|
56 | 79 | :return: Decoded license plate characters as a list.
|
57 | 80 | """
|
58 |
| - x = read_plate_image(image_path) |
| 81 | + x = _load_image_from_source(source) |
59 | 82 | with log_time_taken("Pre-process") if self.log_time else nullcontext():
|
60 | 83 | x = preprocess_image(x, self.config["img_height"], self.config["img_width"])
|
61 | 84 | with log_time_taken("Model run") if self.log_time else nullcontext():
|
|
0 commit comments