Skip to content

Commit fd2955f

Browse files
authored
Remove duplicate export tests in model_test. (#20735)
The same tests exist at: - https://github.com/keras-team/keras/blob/master/keras/src/export/saved_model_test.py#L66 - https://github.com/keras-team/keras/blob/master/keras/src/export/onnx_test.py#L62 The goal is to isolate the use of `onnxruntime` to a single file, `onnx_test.py`.
1 parent f97be63 commit fd2955f

File tree

1 file changed

+0
-69
lines changed

1 file changed

+0
-69
lines changed

keras/src/models/model_test.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
12191219
)
12201220
self.assertListEqual(hist_keys, ref_keys)
12211221

1222-
@parameterized.named_parameters(
1223-
("tf_saved_model", "tf_saved_model"),
1224-
("onnx", "onnx"),
1225-
)
1226-
@pytest.mark.skipif(
1227-
backend.backend() not in ("tensorflow", "jax", "torch"),
1228-
reason=(
1229-
"Currently, `Model.export` only supports the tensorflow, jax and "
1230-
"torch backends."
1231-
),
1232-
)
1233-
@pytest.mark.skipif(
1234-
testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
1235-
)
1236-
def test_export(self, export_format):
1237-
if export_format == "tf_saved_model" and testing.torch_uses_gpu():
1238-
self.skipTest("Leads to core dumps on CI")
1239-
1240-
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
1241-
model = _get_model()
1242-
x1 = np.random.rand(1, 3).astype("float32")
1243-
x2 = np.random.rand(1, 3).astype("float32")
1244-
ref_output = model([x1, x2])
1245-
1246-
model.export(temp_filepath, format=export_format)
1247-
1248-
if export_format == "tf_saved_model":
1249-
import tensorflow as tf
1250-
1251-
revived_model = tf.saved_model.load(temp_filepath)
1252-
self.assertAllClose(ref_output, revived_model.serve([x1, x2]))
1253-
1254-
# Test with a different batch size
1255-
if backend.backend() == "torch":
1256-
# TODO: Dynamic shape is not supported yet in the torch backend
1257-
return
1258-
revived_model.serve(
1259-
[
1260-
np.concatenate([x1, x1], axis=0),
1261-
np.concatenate([x2, x2], axis=0),
1262-
]
1263-
)
1264-
elif export_format == "onnx":
1265-
import onnxruntime
1266-
1267-
ort_session = onnxruntime.InferenceSession(temp_filepath)
1268-
ort_inputs = {
1269-
k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2])
1270-
}
1271-
self.assertAllClose(
1272-
ref_output, ort_session.run(None, ort_inputs)[0]
1273-
)
1274-
1275-
# Test with a different batch size
1276-
if backend.backend() == "torch":
1277-
# TODO: Dynamic shape is not supported yet in the torch backend
1278-
return
1279-
ort_inputs = {
1280-
k.name: v
1281-
for k, v in zip(
1282-
ort_session.get_inputs(),
1283-
[
1284-
np.concatenate([x1, x1], axis=0),
1285-
np.concatenate([x2, x2], axis=0),
1286-
],
1287-
)
1288-
}
1289-
ort_session.run(None, ort_inputs)
1290-
12911222
def test_export_error(self):
12921223
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
12931224
model = _get_model()

0 commit comments

Comments
 (0)