Skip to content

Commit fab303a

Browse files
committed
Move backend utils to proper module to avoid keras being loaded first
1 parent 76df538 commit fab303a

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

fast_plate_ocr/backend_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
Utils for Keras supported backends.
3+
"""
4+
5+
import os
6+
7+
from fast_plate_ocr.custom_types import Framework
8+
9+
10+
def set_jax_backend() -> None:
11+
"""Set Keras backend to jax."""
12+
set_keras_backend("jax")
13+
14+
15+
def set_tensorflow_backend() -> None:
16+
"""Set Keras backend to tensorflow."""
17+
set_keras_backend("tensorflow")
18+
19+
20+
def set_pytorch_backend() -> None:
21+
"""Set Keras backend to pytorch."""
22+
set_keras_backend("torch")
23+
24+
25+
def set_keras_backend(framework: Framework) -> None:
26+
"""Set the Keras backend to a given framework."""
27+
os.environ["KERAS_BACKEND"] = framework

fast_plate_ocr/utils.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import logging
6-
import os
76
import pathlib
87
import random
98
import time
@@ -25,7 +24,6 @@
2524
VOCABULARY_SIZE,
2625
)
2726
from fast_plate_ocr.custom import cat_acc_metric, cce_loss, plate_acc_metric, top_3_k_metric
28-
from fast_plate_ocr.custom_types import Framework
2927

3028

3129
def one_hot_plate(plate: str, alphabet: str = MODEL_ALPHABET) -> list[list[int]]:
@@ -45,26 +43,6 @@ def target_transform(
4543
return encoded_plate
4644

4745

48-
def set_tensorflow_backend() -> None:
49-
"""Set Keras backend to tensorflow."""
50-
set_keras_backend("tensorflow")
51-
52-
53-
def set_jax_backend() -> None:
54-
"""Set Keras backend to jax."""
55-
set_keras_backend("jax")
56-
57-
58-
def set_pytorch_backend() -> None:
59-
"""Set Keras backend to pytorch."""
60-
set_keras_backend("torch")
61-
62-
63-
def set_keras_backend(framework: Framework) -> None:
64-
"""Set the Keras backend to a given framework."""
65-
os.environ["KERAS_BACKEND"] = framework
66-
67-
6846
def read_plate_image(
6947
image_path: str, img_height: int = DEFAULT_IMG_HEIGHT, img_width: int = DEFAULT_IMG_WIDTH
7048
) -> npt.NDArray:

test/fast_lp_ocr/test_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# ruff: noqa: E402
66
# pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports
77
# fmt: off
8-
from fast_plate_ocr.utils import set_pytorch_backend
8+
from fast_plate_ocr.backend_utils import set_pytorch_backend
99

1010
set_pytorch_backend()
1111
# fmt: on

0 commit comments

Comments
 (0)