@@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
1219
1219
)
1220
1220
self .assertListEqual (hist_keys , ref_keys )
1221
1221
1222
- @parameterized .named_parameters (
1223
- ("tf_saved_model" , "tf_saved_model" ),
1224
- ("onnx" , "onnx" ),
1225
- )
1226
- @pytest .mark .skipif (
1227
- backend .backend () not in ("tensorflow" , "jax" , "torch" ),
1228
- reason = (
1229
- "Currently, `Model.export` only supports the tensorflow, jax and "
1230
- "torch backends."
1231
- ),
1232
- )
1233
- @pytest .mark .skipif (
1234
- testing .jax_uses_gpu (), reason = "Leads to core dumps on CI"
1235
- )
1236
- def test_export (self , export_format ):
1237
- if export_format == "tf_saved_model" and testing .torch_uses_gpu ():
1238
- self .skipTest ("Leads to core dumps on CI" )
1239
-
1240
- temp_filepath = os .path .join (self .get_temp_dir (), "exported_model" )
1241
- model = _get_model ()
1242
- x1 = np .random .rand (1 , 3 ).astype ("float32" )
1243
- x2 = np .random .rand (1 , 3 ).astype ("float32" )
1244
- ref_output = model ([x1 , x2 ])
1245
-
1246
- model .export (temp_filepath , format = export_format )
1247
-
1248
- if export_format == "tf_saved_model" :
1249
- import tensorflow as tf
1250
-
1251
- revived_model = tf .saved_model .load (temp_filepath )
1252
- self .assertAllClose (ref_output , revived_model .serve ([x1 , x2 ]))
1253
-
1254
- # Test with a different batch size
1255
- if backend .backend () == "torch" :
1256
- # TODO: Dynamic shape is not supported yet in the torch backend
1257
- return
1258
- revived_model .serve (
1259
- [
1260
- np .concatenate ([x1 , x1 ], axis = 0 ),
1261
- np .concatenate ([x2 , x2 ], axis = 0 ),
1262
- ]
1263
- )
1264
- elif export_format == "onnx" :
1265
- import onnxruntime
1266
-
1267
- ort_session = onnxruntime .InferenceSession (temp_filepath )
1268
- ort_inputs = {
1269
- k .name : v for k , v in zip (ort_session .get_inputs (), [x1 , x2 ])
1270
- }
1271
- self .assertAllClose (
1272
- ref_output , ort_session .run (None , ort_inputs )[0 ]
1273
- )
1274
-
1275
- # Test with a different batch size
1276
- if backend .backend () == "torch" :
1277
- # TODO: Dynamic shape is not supported yet in the torch backend
1278
- return
1279
- ort_inputs = {
1280
- k .name : v
1281
- for k , v in zip (
1282
- ort_session .get_inputs (),
1283
- [
1284
- np .concatenate ([x1 , x1 ], axis = 0 ),
1285
- np .concatenate ([x2 , x2 ], axis = 0 ),
1286
- ],
1287
- )
1288
- }
1289
- ort_session .run (None , ort_inputs )
1290
-
1291
1222
def test_export_error (self ):
1292
1223
temp_filepath = os .path .join (self .get_temp_dir (), "exported_model" )
1293
1224
model = _get_model ()
0 commit comments