Skip to content

Commit 520fffc

Browse files
committed
Improve ONNX allclose message (#12)
1 parent 4dcd032 commit 520fffc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

fast_plate_ocr/cli/onnx_converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def export_onnx(
7474
)
7575
output_names = [n.name for n in model_proto.graph.output]
7676
x = np.random.randint(0, 256, size=(1, config.img_height, config.img_width, 1), dtype=np.uint8)
77-
providers = ["CPUExecutionProvider"]
7877
# Run dummy inference and log time taken
79-
m = rt.InferenceSession(output_path, providers=providers)
78+
m = rt.InferenceSession(output_path)
8079
with log_time_taken("ONNX inference took:"):
8180
onnx_pred = m.run(output_names, {"input": x})
82-
# Check ONNX and keras have the same results
83-
np.testing.assert_allclose(model.predict(x, verbose=0), onnx_pred[0], rtol=1e-5)
84-
logging.info("Model converted successfully to ONNX! Saved at %s", output_path)
81+
# Check if ONNX and keras have the same results
82+
if not np.allclose(model.predict(x, verbose=0), onnx_pred[0], rtol=1e-5, atol=1e-5):
83+
logging.warning("ONNX model output was not close to Keras model for the given tolerance!")
84+
logging.info("Model converted to ONNX! Saved at %s", output_path)
8585

8686

8787
if __name__ == "__main__":

0 commit comments

Comments
 (0)