-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinference.py
171 lines (137 loc) · 6.75 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import time
import argparse
import params
from detection.split_image import calOneImage
from recognition.load_model import load_model
from recognition.recognize_word import recognize_word_dir, postprocess, write_xml
"""
date: 2020-12-18
author@kxie
input: image path
output: ocr result (xml)
usage: python inference.py -i [image path/dir] -o [output save dir]
e.g.
python inference.py -i demo/img/lrb000.png -o demo/xml (single image inference)
python inference.py -i demo/img -o demo/xml (batch images inference)
"""
# Workaround for a macOS start-up abort:
#   "OMP: Error #15: Initializing libiomp5.dylib, but found libiomp5.dylib already initialized."
# Setting KMP_DUPLICATE_LIB_OK tells the Intel OpenMP runtime to tolerate a second
# copy of itself instead of aborting. It is set unconditionally (on every platform).
# NOTE(review): this runs after the detection/recognition imports above; if those
# imports already initialize OpenMP, the setting may come too late — confirm.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
def inference(input_path, output_dir):
    """single image inference: input an image, output ocr result in xml format.

    Runs word detection via ``calOneImage`` (its word crops are expected in
    ``<params.word_dir>/<image stem>``). If that directory exists and is non-empty,
    loads the recognition model, recognizes the crops and writes
    ``<output_dir>/<image stem>.xml``. Progress and timings are printed to stdout;
    nothing is returned.

    params:
    - input_path: input image path
    - output_dir: xml save dir (also passed through to ``calOneImage``)
    """
    def _now():
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    # Use splitext (as inference_batch does) so names with extra dots keep their
    # full stem: "sheet.v2.png" -> "sheet.v2", not "sheet".
    img_name = os.path.splitext(os.path.basename(input_path))[0]
    word_img_dir = os.path.abspath(os.path.join(params.word_dir, img_name))

    # detection
    time_det_start = time.time()
    calOneImage(input_path, word_img_dir, output_dir)
    time_det_elapse = time.time() - time_det_start

    # recognition needs at least one detected word image
    if not (os.path.isdir(word_img_dir) and os.listdir(word_img_dir)):
        print("{} - [ERROR] - {}. No such directory or this directory is empty."
              .format(_now(), word_img_dir))
        return

    time_rec_start = time.time()
    trained_model, device = load_model(params.model_arch, params.model_path)
    pred_dict = recognize_word_dir(word_img_dir, trained_model, device)
    pred_dict = postprocess(pred_dict)
    xml_path = os.path.abspath(os.path.join(output_dir, img_name + '.xml'))
    write_xml(pred_dict, xml_path)
    time_rec_elapse = time.time() - time_rec_start

    print("{} - [INFO] - Success. "
          "{} word bbx detected in {}. "
          "det.time: {:.2f}s, "
          "rec.time: {:.2f}s, "
          "total {:.2f}s. "
          "(with {})"
          .format(_now(),
                  len(os.listdir(word_img_dir)),
                  os.path.basename(input_path),
                  time_det_elapse,
                  time_rec_elapse,
                  time_det_elapse + time_rec_elapse,
                  device))
    print("\n{} - [INFO] - Detected word images saved in {}".format(_now(), word_img_dir))
    # xml_path is already absolute
    print("{} - [INFO] - XML file saved in {}".format(_now(), xml_path))
def inference_batch(input_dir, output_dir):
    """batch images inference: input multiple images, output ocr results in xml format.

    Files in ``input_dir`` are processed in sorted order. The recognition model is
    loaded once and reused for every image. Entries that are not regular files, or
    whose extension (compared case-insensitively) is not a supported image format,
    are skipped with a warning. An exception on one image is reported and the batch
    continues with the next one. A summary with total/average time per successful
    image is printed at the end.

    Args:
    - input_dir (str): input image dir
    - output_dir (str): xml save dir (also passed through to ``calOneImage``)
    """
    def _now():
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    count = 0          # images that produced an XML
    count_all = 0      # images attempted (supported extension)
    time_elapse = 0    # total seconds spent on successful images
    image_format = (".jpg", ".jpeg", ".bmp", ".png")
    file_list = sorted(os.listdir(input_dir))
    num_files = len(file_list)

    # load model once for the whole batch
    trained_model, device = load_model(params.model_arch, params.model_path)

    for idx, file_name in enumerate(file_list, start=1):
        print("{} - [INFO] - {}/{}: ".format(_now(), idx, num_files), end='')
        input_path = os.path.join(input_dir, file_name)
        img_name, ext = os.path.splitext(file_name)
        # Case-insensitive so e.g. "SCAN.PNG" is not silently skipped.
        if not (os.path.isfile(input_path) and ext.lower() in image_format):
            print("{} - [Warning] - {}, pass.".format(_now(), file_name))
            continue

        count_all += 1
        time_start = time.time()
        word_img_dir = os.path.abspath(os.path.join(params.word_dir, img_name))
        try:
            # detection
            calOneImage(input_path, word_img_dir, output_dir)
            # recognition
            if os.path.isdir(word_img_dir) and os.listdir(word_img_dir):
                pred_dict = recognize_word_dir(word_img_dir, trained_model, device)
                pred_dict = postprocess(pred_dict)
                xml_path = os.path.abspath(os.path.join(output_dir, img_name + '.xml'))
                write_xml(pred_dict, xml_path)
                # if success
                time_end = time.time()
                time_elapse += time_end - time_start
                count += 1
                print("{} word bbx detected in {}. total {:.2f}s."
                      .format(len(os.listdir(word_img_dir)), file_name, time_end - time_start))
            else:
                print("{} - [ERROR] - {}. No such directory or this directory is empty."
                      .format(_now(), word_img_dir))
        except Exception as err:
            # Best-effort per image: report why it failed and move on. Using
            # Exception (not a bare except) lets Ctrl-C still stop the batch.
            print("{} - [Error] - {}, inference failed, pass. ({}: {})"
                  .format(_now(), file_name, type(err).__name__, err))

    # Guard against ZeroDivisionError when no image succeeded.
    time_avg = time_elapse / count if count else 0.0
    print("{} - [INFO] - Done. {}/{} success. total: {:.2f}s, avg.: {:.2f}s (with {})"
          .format(_now(), count, count_all, time_elapse, time_avg, device))
    print("{} - [INFO] - Detected word images saved in {}".format(_now(), params.word_dir))
    print("{} - [INFO] - XML saved in {}".format(_now(), output_dir))
if __name__ == "__main__":
    # CLI entry point:
    #   python inference.py -i <image file or dir> -o <xml output dir>
    # (Previously the -i/-o arguments were commented out and replaced by
    # hard-coded developer paths, and an .xml *file* path was passed to
    # os.makedirs, creating a directory with that name.)
    parser = argparse.ArgumentParser(
        description="OCR inference: image file or image directory in, XML results out.")
    parser.add_argument('-i', '--input', required=True,
                        help='path to an input image, or a directory of images')
    parser.add_argument('-o', '--output', required=True,
                        help='directory to save output xml files')
    args = parser.parse_args()

    input_path = os.path.abspath(args.input)
    output_dir = os.path.abspath(args.output)

    # Validate the input before creating anything on disk.
    if not (os.path.isfile(input_path) or os.path.isdir(input_path)):
        raise ValueError("invalid input or output: {}".format(input_path))

    os.makedirs(output_dir, exist_ok=True)

    if os.path.isfile(input_path):
        inference(input_path, output_dir)
    else:
        inference_batch(input_path, output_dir)