Skip to content

Commit

Permalink
Update tf_test.py to use jdl.ArrayDataset and 'tensorflow' backend
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Feb 15, 2024
1 parent c31ad48 commit dd99bea
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions integration_tests/tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


def test_jax():
ds = tf.data.Dataset.from_tensor_slices((tf.ones((10, 3)), tf.ones((10, 3))))
dl = jdl.DataLoader(ds, 'jax', batch_size=2)
ds = jdl.ArrayDataset(np.ones((10, 3)), np.ones((10, 3)))
dl = jdl.DataLoader(ds, 'tensorflow', batch_size=2)
for x, y in dl:
z = x + y
assert isinstance(z, np.ndarray)
Expand Down

0 comments on commit dd99bea

Please sign in to comment.